Alfred Liu committed on
Commit ·
d19bd3e
1
Parent(s): b0800d3
Code release
Browse files- .gitignore +53 -0
- README.md +131 -1
- configs/r101_nuimg_1408x512_900q_24ep.py +96 -0
- configs/r50_in1k_704x256_900q_36ep.py +21 -0
- configs/r50_nuimg_704x256_400q_36ep.py +236 -0
- gen_sweep_info.py +112 -0
- loaders/__init__.py +6 -0
- loaders/builder.py +49 -0
- loaders/nuscenes_dataset.py +88 -0
- loaders/pipelines/__init__.py +7 -0
- loaders/pipelines/loading.py +154 -0
- loaders/pipelines/transforms.py +394 -0
- models/__init__.py +9 -0
- models/backbones/__init__.py +3 -0
- models/backbones/vovnet.py +383 -0
- models/bbox/__init__.py +3 -0
- models/bbox/assigners/__init__.py +3 -0
- models/bbox/assigners/hungarian_assigner_3d.py +93 -0
- models/bbox/coders/__init__.py +3 -0
- models/bbox/coders/nms_free_coder.py +110 -0
- models/bbox/match_costs/__init__.py +3 -0
- models/bbox/match_costs/match_cost.py +53 -0
- models/bbox/utils.py +77 -0
- models/checkpoint.py +444 -0
- models/csrc/__init__.py +0 -0
- models/csrc/msmv_sampling/msmv_sampling.cpp +369 -0
- models/csrc/msmv_sampling/msmv_sampling.h +43 -0
- models/csrc/msmv_sampling/msmv_sampling_backward.cu +448 -0
- models/csrc/msmv_sampling/msmv_sampling_forward.cu +333 -0
- models/csrc/setup.py +24 -0
- models/csrc/wrapper.py +87 -0
- models/sparsebev.py +322 -0
- models/sparsebev_head.py +469 -0
- models/sparsebev_sampling.py +102 -0
- models/sparsebev_transformer.py +370 -0
- models/utils.py +309 -0
- timing.py +100 -0
- train.py +180 -0
- utils.py +188 -0
- val.py +137 -0
.gitignore
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OS generated files
|
| 2 |
+
.DS_Store
|
| 3 |
+
.DS_Store?
|
| 4 |
+
._*
|
| 5 |
+
.Spotlight-V100
|
| 6 |
+
.Trashes
|
| 7 |
+
ehthumbs.db
|
| 8 |
+
Thumbs.db
|
| 9 |
+
|
| 10 |
+
# Compiled source
|
| 11 |
+
build
|
| 12 |
+
debug
|
| 13 |
+
Debug
|
| 14 |
+
release
|
| 15 |
+
Release
|
| 16 |
+
x64
|
| 17 |
+
*.so
|
| 18 |
+
*.whl
|
| 19 |
+
|
| 20 |
+
# VS project files
|
| 21 |
+
*.sln
|
| 22 |
+
*.vcxproj
|
| 23 |
+
*.vcxproj.filters
|
| 24 |
+
*.vcxproj.user
|
| 25 |
+
*.rc
|
| 26 |
+
.vs
|
| 27 |
+
|
| 28 |
+
# Byte-compiled / optimized / DLL files
|
| 29 |
+
*__pycache__*
|
| 30 |
+
*.py[cod]
|
| 31 |
+
*$py.class
|
| 32 |
+
|
| 33 |
+
# Distribution / packaging
|
| 34 |
+
.Python
|
| 35 |
+
build
|
| 36 |
+
develop-eggs
|
| 37 |
+
dist
|
| 38 |
+
downloads
|
| 39 |
+
|
| 40 |
+
# IDE
|
| 41 |
+
.idea
|
| 42 |
+
.vscode
|
| 43 |
+
pyrightconfig.json
|
| 44 |
+
|
| 45 |
+
# Custom
|
| 46 |
+
data
|
| 47 |
+
outputs
|
| 48 |
+
prediction
|
| 49 |
+
submission
|
| 50 |
+
checkpoints
|
| 51 |
+
pretrain
|
| 52 |
+
*.png
|
| 53 |
+
*.jpg
|
README.md
CHANGED
|
@@ -1,2 +1,132 @@
|
|
| 1 |
# SparseBEV
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# SparseBEV
|
| 2 |
+
|
| 3 |
+
This is the official PyTorch implementation for paper [SparseBEV: High-Performance Sparse 3D Object Detection from Multi-Camera Videos](https://arxiv.org/abs/2308.09244). (ICCV 2023)
|
| 4 |
+
|
| 5 |
+
## Model Zoo
|
| 6 |
+
|
| 7 |
+
| Backbone | Pretrain | Input Size | Epochs | Training Cost | NDS | FPS | Config | Weights |
|
| 8 |
+
|----------|----------|------------|--------|---------------|-----|-----|--------|---------|
|
| 9 |
+
| R50 | [nuImages](https://github.com/open-mmlab/mmdetection3d/blob/main/configs/nuimages/cascade-mask-rcnn_r50_fpn_coco-20e_20e_nuim.py) | 704x256 | 36 | 28h (8x2080Ti) | 55.8 | 23.5 | [config](configs/r50_nuimg_704x256_400q_36ep.py) | [weights](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view?usp=sharing) |
|
| 10 |
+
|
| 11 |
+
* FPS is measured on a machine with AMD 5800X and RTX 3090.
|
| 12 |
+
* The noise is around 0.3 NDS.
|
| 13 |
+
|
| 14 |
+
## Environment
|
| 15 |
+
|
| 16 |
+
Install PyTorch 2.0 + CUDA 11.8:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
conda create -n sparsebev python=3.8
|
| 20 |
+
conda activate sparsebev
|
| 21 |
+
conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
or PyTorch 1.10.2 + CUDA 10.2 for older GPUs:
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
conda create -n sparsebev python=3.8
|
| 28 |
+
conda activate sparsebev
|
| 29 |
+
conda install pytorch==1.10.2 torchvision==0.11.3 cudatoolkit=10.2 -c pytorch
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Install other dependencies:
|
| 33 |
+
|
| 34 |
+
```
|
| 35 |
+
pip install openmim
|
| 36 |
+
mim install mmcv-full==1.6.0
|
| 37 |
+
mim install mmdet==2.28.2
|
| 38 |
+
mim install mmsegmentation==0.30.0
|
| 39 |
+
mim install mmdet3d==1.0.0rc6
|
| 40 |
+
pip install setuptools==59.5.0
|
| 41 |
+
pip install numpy==1.23.5
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Install turbojpeg and pillow-simd to speed up data loading (optional but important):
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
sudo apt-get update
|
| 48 |
+
sudo apt-get install -y libturbojpeg
|
| 49 |
+
pip install pyturbojpeg
|
| 50 |
+
pip uninstall pillow
|
| 51 |
+
pip install pillow-simd==9.0.0.post1
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Compile CUDA extensions:
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
cd models/csrc
|
| 58 |
+
python setup.py build_ext --inplace
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Prepare Dataset
|
| 62 |
+
|
| 63 |
+
1. Download nuScenes from [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes) and put it in `data/nuscenes`.
|
| 64 |
+
2. Download the generated info file from [Google Drive](https://drive.google.com/file/d/1uyoUuSRIVScrm_CUpge6V_UzwDT61ODO/view?usp=sharing) and unzip it.
|
| 65 |
+
3. Folder structure:
|
| 66 |
+
|
| 67 |
+
```
|
| 68 |
+
data/nuscenes
|
| 69 |
+
├── maps
|
| 70 |
+
├── nuscenes_infos_test_sweep.pkl
|
| 71 |
+
├── nuscenes_infos_train_mini_sweep.pkl
|
| 72 |
+
├── nuscenes_infos_train_sweep.pkl
|
| 73 |
+
├── nuscenes_infos_val_mini_sweep.pkl
|
| 74 |
+
├── nuscenes_infos_val_sweep.pkl
|
| 75 |
+
├── samples
|
| 76 |
+
├── sweeps
|
| 77 |
+
├── v1.0-mini
|
| 78 |
+
├── v1.0-test
|
| 79 |
+
└── v1.0-trainval
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
These `*.pkl` files can also be generated with our script: `gen_sweep_info.py`.
|
| 83 |
+
|
| 84 |
+
## Training
|
| 85 |
+
|
| 86 |
+
Train SparseBEV with 8 GPUs:
|
| 87 |
+
|
| 88 |
+
```
|
| 89 |
+
torchrun --nproc_per_node 8 train.py --config configs/r50_nuimg_704x256_400q_36ep.py
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Train SparseBEV with 4 GPUs (i.e. the last four GPUs):
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 96 |
+
torchrun --nproc_per_node 4 train.py --config configs/r50_nuimg_704x256_400q_36ep.py
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
The batch size for each GPU will be scaled automatically. So there is no need to modify the `batch_size` in config files.
|
| 100 |
+
|
| 101 |
+
## Evaluation
|
| 102 |
+
|
| 103 |
+
Single-GPU evaluation:
|
| 104 |
+
|
| 105 |
+
```
|
| 106 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 107 |
+
python val.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
Multi-GPU evaluation:
|
| 111 |
+
|
| 112 |
+
```
|
| 113 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
| 114 |
+
torchrun --nproc_per_node 8 val.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Timing
|
| 118 |
+
|
| 119 |
+
FPS is measured with a single GPU:
|
| 120 |
+
|
| 121 |
+
```
|
| 122 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 123 |
+
python timing.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
## Acknowledgements
|
| 127 |
+
|
| 128 |
+
Many thanks to these excellent open-source projects:
|
| 129 |
+
|
| 130 |
+
* 3D Detection: [DETR3D](https://github.com/WangYueFt/detr3d), [PETR](https://github.com/megvii-research/PETR), [BEVFormer](https://github.com/fundamentalvision/BEVFormer), [BEVDet](https://github.com/HuangJunJie2017/BEVDet), [StreamPETR](https://github.com/exiawsh/StreamPETR)
|
| 131 |
+
* 2D Detection: [AdaMixer](https://github.com/MCG-NJU/AdaMixer), [DN-DETR](https://github.com/IDEA-Research/DN-DETR)
|
| 132 |
+
* Codebase: [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [CamLiFlow](https://github.com/MCG-NJU/CamLiFlow)
|
configs/r101_nuimg_1408x512_900q_24ep.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SparseBEV config: ResNet-101 backbone, 1408x512 input, 900 queries, 24 epochs.
# Inherits everything not overridden here from the R50 base config.
_base_ = ['./r50_nuimg_704x256_400q_36ep.py']

# For nuScenes we usually do 10-class detection
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

img_backbone = dict(
    type='ResNet',
    depth=101,
    # gradient checkpointing to fit the larger backbone in memory
    with_cp=True,
)

img_neck = dict(
    type='FPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    # 5 FPN levels here (base config uses 4) — matched by transformer num_levels below
    num_outs=5,
)

model = dict(
    img_backbone=img_backbone,
    img_neck=img_neck,
    pts_bbox_head=dict(
        num_query=900,
        transformer=dict(num_levels=5)),
)

# Image-domain augmentation: resize limits are doubled relative to the base
# config because the target resolution (1408x512) is twice as large.
ida_aug_conf = {
    'resize_lim': (0.38 * 2, 0.55 * 2),
    'final_dim': (512, 1408),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=7),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
        'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
]

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=7, test_mode=True),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1600, 900),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
            dict(type='Collect3D', keys=['img'], meta_keys=(
                'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
                'lidar2img', 'img_timestamp'))
        ])
]

data = dict(
    # fewer workers than the base config — presumably to limit host memory at
    # the larger input size; TODO confirm
    workers_per_gpu=4,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline)
)

optimizer = dict(
    type='AdamW',
    lr=2e-4,
    paramwise_cfg=dict(custom_keys={
        # smaller LR for the pretrained backbone and the sampling offsets
        'img_backbone': dict(lr_mult=0.2),
        'sampling_offset': dict(lr_mult=0.1),
    }),
    weight_decay=0.01
)

# load pretrained weights
load_from = 'pretrain/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth'
# rename 'backbone.*' keys in the checkpoint to 'img_backbone.*' when loading
revise_keys = [('backbone', 'img_backbone')]

total_epochs = 24
eval_config = dict(interval=total_epochs)
|
configs/r50_in1k_704x256_900q_36ep.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SparseBEV config: ResNet-50 with ImageNet-1k (torchvision) pretraining,
# 704x256 input, 900 queries, 36 epochs. Inherits from the nuImages R50 config.
_base_ = ['./r50_nuimg_704x256_400q_36ep.py']

# Use torchvision's ImageNet-pretrained weights instead of the nuImages checkpoint.
img_backbone = dict(pretrained='torchvision://resnet50')

model = dict(
    img_backbone=img_backbone,
    pts_bbox_head=dict(num_query=900)
)

optimizer = dict(
    paramwise_cfg=dict(custom_keys={
        # ImageNet-only backbone gets a larger lr_mult (0.4) than the base
        # config's nuImages-pretrained backbone (0.1)
        'img_backbone': dict(lr_mult=0.4),
        'sampling_offset': dict(lr_mult=0.1),
    })
)

# No detector checkpoint to load; the backbone weights come from torchvision.
load_from = None
revise_keys = None

total_epochs = 36
eval_config = dict(interval=total_epochs)
|
configs/r50_nuimg_704x256_400q_36ep.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SparseBEV base config: ResNet-50 (nuImages-pretrained), 704x256 input,
# 400 queries, 36 epochs. Other configs in this directory inherit from it.
dataset_type = 'CustomNuScenesDataset'
dataset_root = 'data/nuscenes/'

# Camera-only setup; 'use_external=True' flags use of external (pretrained) data.
input_modality = dict(
    use_lidar=False,
    use_camera=True,
    use_radar=False,
    use_map=False,
    use_external=True
)

# For nuScenes we usually do 10-class detection
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

# arch config
embed_dims = 256   # transformer feature width
num_layers = 6     # decoder layers
num_query = 400    # object queries
num_frames = 8     # temporal frames (1 key frame + 7 sweeps, see pipelines)
num_levels = 4     # FPN levels
num_points = 4     # sampling points per query

img_backbone = dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN2d', requires_grad=True),
    norm_eval=True,
    style='pytorch',
    with_cp=True)
img_neck = dict(
    type='FPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=embed_dims,
    num_outs=num_levels)
# Standard ImageNet mean/std normalization, applied on GPU (see data_aug below).
img_norm_cfg = dict(
    mean=[123.675, 116.280, 103.530],
    std=[58.395, 57.120, 57.375],
    to_rgb=True)

model = dict(
    type='SparseBEV',
    data_aug=dict(
        img_color_aug=True,  # Move some augmentations to GPU
        img_norm_cfg=img_norm_cfg,
        img_pad_cfg=dict(size_divisor=32)),
    stop_prev_grad=False,
    img_backbone=img_backbone,
    img_neck=img_neck,
    pts_bbox_head=dict(
        type='SparseBEVHead',
        num_classes=10,
        in_channels=embed_dims,
        num_query=num_query,
        query_denoising=True,
        query_denoising_groups=10,
        code_size=10,
        # per-dimension weights of the 10-dim box code; center x/y weighted 2x
        code_weights=[2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        sync_cls_avg_factor=True,
        transformer=dict(
            type='SparseBEVTransformer',
            embed_dims=embed_dims,
            num_frames=num_frames,
            num_points=num_points,
            num_layers=num_layers,
            num_levels=num_levels,
            num_classes=10,
            code_size=10,
            pc_range=point_cloud_range),
        bbox_coder=dict(
            type='NMSFreeCoder',
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            pc_range=point_cloud_range,
            max_num=300,
            voxel_size=voxel_size,
            score_threshold=0.05,
            num_classes=10),
        positional_encoding=dict(
            type='SinePositionalEncoding',
            num_feats=embed_dims // 2,
            normalize=True,
            offset=-0.5),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0),
        loss_bbox=dict(type='L1Loss', loss_weight=0.25),
        # GIoU loss present but disabled (weight 0.0), mirrored by iou_cost below
        loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
    train_cfg=dict(pts=dict(
        grid_size=[512, 512, 1],
        voxel_size=voxel_size,
        point_cloud_range=point_cloud_range,
        out_size_factor=4,
        assigner=dict(
            type='HungarianAssigner3D',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
            iou_cost=dict(type='IoUCost', weight=0.0),
        )
    ))
)

# Image-domain augmentation (resize/crop/flip) for 704x256 input from
# the native 1600x900 nuScenes cameras.
ida_aug_conf = {
    'resize_lim': (0.38, 0.55),
    'final_dim': (256, 704),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=num_frames - 1),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
        'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
]

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=num_frames - 1, test_mode=True),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1600, 900),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
            dict(type='Collect3D', keys=['img'], meta_keys=(
                'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
                'lidar2img', 'img_timestamp'))
        ])
]

data = dict(
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_root=dataset_root,
        ann_file=dataset_root + 'nuscenes_infos_train_sweep.pkl',
        pipeline=train_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=False,
        use_valid_flag=True,
        box_type_3d='LiDAR'),
    val=dict(
        type=dataset_type,
        data_root=dataset_root,
        ann_file=dataset_root + 'nuscenes_infos_val_sweep.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        box_type_3d='LiDAR'),
    test=dict(
        type=dataset_type,
        data_root=dataset_root,
        # NOTE(review): test split uses 'nuscenes_custom_infos_test.pkl' while
        # the README/gen_sweep_info.py produce 'nuscenes_infos_test_sweep.pkl'
        # — confirm which file name is intended.
        ann_file=dataset_root + 'nuscenes_custom_infos_test.pkl',
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        box_type_3d='LiDAR')
)

optimizer = dict(
    type='AdamW',
    lr=2e-4,
    paramwise_cfg=dict(custom_keys={
        'img_backbone': dict(lr_mult=0.1),
        'sampling_offset': dict(lr_mult=0.1),
    }),
    weight_decay=0.01
)

# Mixed-precision training with static loss scaling and gradient clipping.
optimizer_config = dict(
    type='Fp16OptimizerHook',
    loss_scale=512.0,
    grad_clip=dict(max_norm=35, norm_type=2)
)

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    min_lr_ratio=1e-3
)
total_epochs = 36
# total batch size; scaled per-GPU by the training script
batch_size = 8

# load pretrained weights
load_from = 'pretrain/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth'
revise_keys = [('backbone', 'img_backbone')]

# resume the last training
resume_from = None

# checkpointing
checkpoint_config = dict(interval=1, max_keep_ckpts=1)

# logging
log_config = dict(
    interval=1,
    hooks=[
        dict(type='MyTextLoggerHook', interval=1, reset_flag=True),
        dict(type='MyTensorboardLoggerHook', interval=500, reset_flag=True)
    ]
)

# evaluation
eval_config = dict(interval=total_epochs)

# other flags
debug = False
|
gen_sweep_info.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generate info files manually: augment mmdet3d-style nuScenes info pickles
# with per-camera sweep information (see add_sweep_info below).
import os
import mmcv
import tqdm
import pickle
import argparse
import numpy as np
from nuscenes import NuScenes
from pyquaternion import Quaternion


# CLI: where the nuScenes data lives and which split/version to process.
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', default='data/nuscenes')
parser.add_argument('--version', default='v1.0-trainval')
args = parser.parse_args()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_cam_info(nusc, sample_data):
    """Build a per-camera info dict for one nuScenes ``sample_data`` record.

    Returns the image path, sensor->global pose, camera intrinsics and the
    capture timestamp (microseconds, as stored by nuScenes).
    """
    pose_record = nusc.get('ego_pose', sample_data['ego_pose_token'])
    cs_record = nusc.get('calibrated_sensor', sample_data['calibrated_sensor_token'])

    sensor2ego_translation = cs_record['translation']
    ego2global_translation = pose_record['translation']
    sensor2ego_rotation = Quaternion(cs_record['rotation']).rotation_matrix
    ego2global_rotation = Quaternion(pose_record['rotation']).rotation_matrix
    cam_intrinsic = np.array(cs_record['camera_intrinsic'])

    # Compose sensor->ego with ego->global. The rotation is stored transposed:
    # R_s2e^T @ R_e2g^T == (R_e2g @ R_s2e)^T, i.e. a row-vector convention —
    # NOTE(review): consumers must use the same convention; confirm against
    # the loading pipeline.
    sensor2global_rotation = sensor2ego_rotation.T @ ego2global_rotation.T
    # Translation: rotate the sensor offset into the global frame, then add the
    # ego position (t @ R^T is equivalent to R @ t for a 1-D vector).
    sensor2global_translation = sensor2ego_translation @ ego2global_rotation.T + ego2global_translation

    return {
        'data_path': os.path.join(args.data_root, sample_data['filename']),
        'sensor2global_rotation': sensor2global_rotation,
        'sensor2global_translation': sensor2global_translation,
        'cam_intrinsic': cam_intrinsic,
        'timestamp': sample_data['timestamp'],
    }
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def add_sweep_info(nusc, sample_infos):
    """Augment every key-frame entry in ``sample_infos['infos']`` with global
    camera poses and up to 5 preceding (non-key-frame) camera sweeps.

    Mutates ``sample_infos`` in place and also returns it.
    """
    for curr_id in tqdm.tqdm(range(len(sample_infos['infos']))):
        sample = nusc.get('sample', sample_infos['infos'][curr_id]['token'])

        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT'
        ]

        # Current sample_data record per camera; walked backwards in time below.
        curr_cams = dict()
        for cam in cam_types:
            curr_cams[cam] = nusc.get('sample_data', sample['data'][cam])

        # Overwrite the key-frame camera entries with the fields produced by
        # get_cam_info (path, global pose, intrinsics, timestamp).
        for cam in cam_types:
            sample_data = nusc.get('sample_data', sample['data'][cam])
            sweep_cam = get_cam_info(nusc, sample_data)
            sample_infos['infos'][curr_id]['cams'][cam].update(sweep_cam)

        # remove unnecessary (superseded by the global-frame fields above)
        for cam in cam_types:
            del sample_infos['infos'][curr_id]['cams'][cam]['sample_data_token']
            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_translation']
            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_rotation']
            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_translation']
            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_rotation']

        sweep_infos = []
        if sample['prev'] != '':  # add sweep frame between two key frame
            for _ in range(5):
                sweep_info = dict()
                for cam in cam_types:
                    if curr_cams[cam]['prev'] == '':
                        # This camera has no earlier sweep: reuse the previous
                        # sweep entry wholesale. NOTE(review): if this fires on
                        # the very first iteration, sweep_infos is empty and
                        # sweep_infos[-1] raises IndexError — confirm it cannot
                        # happen when sample['prev'] is non-empty.
                        sweep_info = sweep_infos[-1]
                        break
                    sample_data = nusc.get('sample_data', curr_cams[cam]['prev'])
                    sweep_cam = get_cam_info(nusc, sample_data)
                    curr_cams[cam] = sample_data
                    sweep_info[cam] = sweep_cam
                sweep_infos.append(sweep_info)

        sample_infos['infos'][curr_id]['sweeps'] = sweep_infos

    return sample_infos
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == '__main__':
    nusc = NuScenes(args.version, args.data_root)

    def _convert(in_name, out_name):
        """Load one info pickle, attach sweep info, and dump the result."""
        # 'with' fixes the original's leaked file handle from
        # pickle.load(open(...)).
        with open(os.path.join(args.data_root, in_name), 'rb') as f:
            sample_infos = pickle.load(f)
        sample_infos = add_sweep_info(nusc, sample_infos)
        mmcv.dump(sample_infos, os.path.join(args.data_root, out_name))

    # (input info file, output info file) pairs per nuScenes version.
    splits = {
        'v1.0-trainval': [
            ('nuscenes_infos_train.pkl', 'nuscenes_infos_train_sweep.pkl'),
            ('nuscenes_infos_val.pkl', 'nuscenes_infos_val_sweep.pkl'),
        ],
        'v1.0-test': [
            ('nuscenes_infos_test.pkl', 'nuscenes_infos_test_sweep.pkl'),
        ],
        'v1.0-mini': [
            ('nuscenes_infos_train_mini.pkl', 'nuscenes_infos_train_mini_sweep.pkl'),
            ('nuscenes_infos_val_mini.pkl', 'nuscenes_infos_val_mini_sweep.pkl'),
        ],
    }

    if args.version not in splits:
        raise ValueError(f'unsupported nuScenes version: {args.version!r}')
    for in_name, out_name in splits[args.version]:
        _convert(in_name, out_name)
|
loaders/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package init: importing .pipelines registers the custom pipeline transforms;
# the imported __all__ is immediately rebound to this package's public API.
from .pipelines import __all__
from .nuscenes_dataset import CustomNuScenesDataset

__all__ = [
    'CustomNuScenesDataset'
]
|
loaders/builder.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from mmcv.parallel import collate
|
| 3 |
+
from mmcv.runner import get_dist_info
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from mmdet.datasets.builder import worker_init_fn
|
| 6 |
+
from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    """Build a PyTorch ``DataLoader`` for the given dataset.

    Args:
        dataset: Dataset from which to load the data.
        samples_per_gpu (int): Batch size on each GPU.
        workers_per_gpu (int): Dataloader workers per GPU.
        num_gpus (int): Number of GPUs; only used in non-distributed mode.
        dist (bool): Whether running in distributed mode.
        shuffle (bool): Whether to shuffle the data each epoch.
        seed (int | None): Seed forwarded to the sampler / worker init.
        **kwargs: Extra keyword arguments passed through to ``DataLoader``.

    Returns:
        torch.utils.data.DataLoader: The configured dataloader.
    """
    rank, world_size = get_dist_info()

    if dist:
        # In distributed mode each process loads its own shard, so the
        # per-process batch size is exactly samples_per_gpu.
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
        if shuffle:
            # DistributedGroupSampler always shuffles, while keeping images
            # of the same aspect-ratio group together on each GPU.
            sampler = DistributedGroupSampler(
                dataset, samples_per_gpu, world_size, rank, seed=seed)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False, seed=seed)
    else:
        # Single-process mode: one loader feeds all GPUs.
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None

    # Seed each worker deterministically when a seed was supplied.
    if seed is None:
        init_fn = None
    else:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)
|
loaders/nuscenes_dataset.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from mmdet.datasets import DATASETS
|
| 4 |
+
from mmdet3d.datasets import NuScenesDataset
|
| 5 |
+
from pyquaternion import Quaternion
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@DATASETS.register_module()
class CustomNuScenesDataset(NuScenesDataset):
    """nuScenes dataset that additionally exposes past/future camera sweeps.

    Extends mmdet3d's ``NuScenesDataset`` so that each sample carries the
    surrounding non-keyframe camera sweeps under ``input_dict['sweeps']``,
    plus the ego/lidar poses needed to reproject them.
    """

    def collect_sweeps(self, index, into_past=60, into_future=0):
        """Collect camera sweeps around the sample at ``index``.

        Args:
            index (int): Index into ``self.data_infos``.
            into_past (int): Minimum number of past entries to gather.
            into_future (int): Minimum number of future entries to gather.

        Returns:
            tuple(list, list): (past sweeps, future sweeps). Each entry is a
            per-camera dict taken from the info files.
        """
        all_sweeps_prev = []
        curr_index = index
        while len(all_sweeps_prev) < into_past:
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            # An empty sweep list marks the start of a scene: stop there.
            if len(curr_sweeps) == 0:
                break
            all_sweeps_prev.extend(curr_sweeps)
            # Also include the keyframe cameras of the previous sample.
            # NOTE(review): when curr_index == 0 this indexes
            # data_infos[-1] (Python wrap-around to the last sample) —
            # presumably unreachable because the first sample of a scene has
            # an empty sweep list; confirm against the generated info files.
            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
            curr_index = curr_index - 1

        all_sweeps_next = []
        curr_index = index + 1
        while len(all_sweeps_next) < into_future:
            if curr_index >= len(self.data_infos):
                break
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            # Reverse so entries run forward in time away from the keyframe.
            all_sweeps_next.extend(curr_sweeps[::-1])
            all_sweeps_next.append(self.data_infos[curr_index]['cams'])
            curr_index = curr_index + 1

        return all_sweeps_prev, all_sweeps_next

    def get_data_info(self, index):
        """Build the input dict consumed by the data pipeline.

        Args:
            index (int): Index of the sample.

        Returns:
            dict: Sample metadata including token, sweeps, poses, and (when
            cameras are enabled) image paths, timestamps and per-view
            lidar2img projection matrices. Annotations are attached unless
            in test mode.
        """
        info = self.data_infos[index]
        sweeps_prev, sweeps_next = self.collect_sweeps(index)

        ego2global_translation = info['ego2global_translation']
        ego2global_rotation = info['ego2global_rotation']
        lidar2ego_translation = info['lidar2ego_translation']
        lidar2ego_rotation = info['lidar2ego_rotation']
        # Convert quaternions to 3x3 rotation matrices for downstream math.
        ego2global_rotation = Quaternion(ego2global_rotation).rotation_matrix
        lidar2ego_rotation = Quaternion(lidar2ego_rotation).rotation_matrix

        input_dict = dict(
            sample_idx=info['token'],
            sweeps={'prev': sweeps_prev, 'next': sweeps_next},
            # Timestamps in the info files are microseconds; convert to s.
            timestamp=info['timestamp'] / 1e6,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation,
            lidar2ego_translation=lidar2ego_translation,
            lidar2ego_rotation=lidar2ego_rotation,
        )

        if self.modality['use_camera']:
            img_paths = []
            img_timestamps = []
            lidar2img_rts = []

            for _, cam_info in info['cams'].items():
                img_paths.append(os.path.relpath(cam_info['data_path']))
                img_timestamps.append(cam_info['timestamp'] / 1e6)

                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T

                # 4x4 extrinsic laid out for the row-vector convention used
                # below (translation in the last row, negated).
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t

                # Embed the 3x3 intrinsics into a 4x4 "viewpad" matrix.
                intrinsic = cam_info['cam_intrinsic']
                viewpad = np.eye(4)
                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2img_rt = (viewpad @ lidar2cam_rt.T)
                lidar2img_rts.append(lidar2img_rt)

            input_dict.update(dict(
                img_filename=img_paths,
                img_timestamp=img_timestamps,
                lidar2img=lidar2img_rts,
            ))

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos

        return input_dict
|
loaders/pipelines/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Importing the pipeline modules runs their @PIPELINES.register_module()
# decorators, making these transforms available by name in configs.
from .loading import LoadMultiViewImageFromMultiSweeps
from .transforms import (NormalizeMultiviewImage, PadMultiViewImage,
                         PhotoMetricDistortionMultiViewImage)

__all__ = [
    'LoadMultiViewImageFromMultiSweeps',
    'PadMultiViewImage',
    'NormalizeMultiviewImage',
    'PhotoMetricDistortionMultiViewImage',
]
|
loaders/pipelines/loading.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import mmcv
|
| 3 |
+
import numpy as np
|
| 4 |
+
from mmdet.datasets.builder import PIPELINES
|
| 5 |
+
from numpy.linalg import inv
|
| 6 |
+
from mmcv.runner import get_dist_info
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def compose_lidar2img(ego2global_translation_curr,
                      ego2global_rotation_curr,
                      lidar2ego_translation_curr,
                      lidar2ego_rotation_curr,
                      sensor2global_translation_past,
                      sensor2global_rotation_past,
                      cam_intrinsic_past):
    """Compose a 4x4 projection from the current lidar frame to a past
    camera's image plane.

    Args:
        ego2global_translation_curr (ndarray): (3,) current ego->global t.
        ego2global_rotation_curr (ndarray): (3, 3) current ego->global R.
        lidar2ego_translation_curr (ndarray): (3,) current lidar->ego t.
        lidar2ego_rotation_curr (ndarray): (3, 3) current lidar->ego R.
        sensor2global_translation_past (ndarray): (3,) past cam->global t.
        sensor2global_rotation_past (ndarray): (3, 3) past cam->global R.
        cam_intrinsic_past (ndarray): (3, 3) past camera intrinsics.

    Returns:
        ndarray: (4, 4) float32 lidar2img matrix.
    """
    # Common factor inv(ego2global)^T @ inv(lidar2ego)^T, reused by both the
    # rotation and the translation compositions below.
    frame_factor = inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T

    rot = sensor2global_rotation_past @ frame_factor
    trans = sensor2global_translation_past @ frame_factor
    trans = trans - (ego2global_translation_curr @ frame_factor
                     + lidar2ego_translation_curr @ inv(lidar2ego_rotation_curr).T)

    cam_rot = inv(rot.T)
    cam_trans = trans @ cam_rot.T

    # Extrinsic laid out for the row-vector convention (translation in the
    # last row, negated), matching the keyframe lidar2img construction.
    extrinsic = np.eye(4)
    extrinsic[:3, :3] = cam_rot.T
    extrinsic[3, :3] = -cam_trans

    # Pad the intrinsics into a 4x4 matrix and compose.
    intrinsic4 = np.eye(4)
    intrinsic4[:cam_intrinsic_past.shape[0], :cam_intrinsic_past.shape[1]] = cam_intrinsic_past
    return (intrinsic4 @ extrinsic.T).astype(np.float32)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@PIPELINES.register_module()
class LoadMultiViewImageFromMultiSweeps(object):
    """Append past camera sweeps to the current multi-view sample.

    Extends ``results['img']`` / ``results['img_timestamp']`` /
    ``results['filename']`` / ``results['lidar2img']`` with entries for
    ``sweeps_num`` past sweeps (6 cameras each), chosen from
    ``results['sweeps']['prev']`` after the keyframe entries loaded by an
    earlier pipeline step.
    """

    def __init__(self,
                 sweeps_num=5,
                 color_type='color',
                 test_mode=False):
        # sweeps_num: number of past sweeps appended; 0 makes __call__ a no-op.
        self.sweeps_num = sweeps_num
        # color_type: flag forwarded to mmcv.imread.
        self.color_type = color_type
        self.test_mode = test_mode

        # Temporal stride between selected sweeps: random in [4, 8] during
        # training, fixed 6 at test time.
        self.train_interval = [4, 8]
        self.test_interval = 6

        # Prefer the faster TurboJPEG decoding backend when installed.
        try:
            mmcv.use_backend('turbojpeg')
        except ImportError:
            mmcv.use_backend('cv2')

    def load_offline(self, results):
        """Decode sweep images from disk and append them to ``results``."""
        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if len(results['sweeps']['prev']) == 0:
            # No history (e.g. first keyframe of a scene): repeat the
            # keyframe entries so downstream code sees a fixed-length list.
            # NOTE(review): the image arrays are appended by reference (only
            # lidar2img is copied) — presumably safe because images are not
            # mutated in place later; confirm against the rest of the
            # pipeline.
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img'].append(results['img'][j])
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
        else:
            if self.test_mode:
                # Deterministic, evenly spaced indices for evaluation.
                interval = self.test_interval
                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]
            elif len(results['sweeps']['prev']) <= self.sweeps_num:
                # Fewer sweeps than requested: take all, pad with the oldest.
                pad_len = self.sweeps_num - len(results['sweeps']['prev'])
                choices = list(range(len(results['sweeps']['prev']))) + [len(results['sweeps']['prev']) - 1] * pad_len
            else:
                # Randomize the stride within [min_interval, max_interval]
                # so training sees varying temporal spacings.
                max_interval = len(results['sweeps']['prev']) // self.sweeps_num
                max_interval = min(max_interval, self.train_interval[1])
                min_interval = min(max_interval, self.train_interval[0])
                interval = np.random.randint(min_interval, max_interval + 1)
                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                # Clamp so an over-long stride falls back to the oldest sweep.
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]

                # Some sweeps are missing cameras; fall back to a neighbor.
                # NOTE(review): sweep_idx - 1 can be -1 and wrap to the last
                # element — verify incomplete sweeps never occur at index 0.
                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    results['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
                    # Timestamps are microseconds in the info files.
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))

        return results

    def load_online(self, results):
        """Append sweep metadata without decoding images (FPS measurement)."""
        # only used when measuring FPS
        assert self.test_mode
        assert self.test_interval == 6

        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if len(results['sweeps']['prev']) == 0:
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
        else:
            interval = self.test_interval
            choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]

                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    # skip loading history frames
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))

        return results

    def __call__(self, results):
        """Dispatch to the online (timing) or offline (full) loader."""
        if self.sweeps_num == 0:
            return results

        # A single-process test run is assumed to be an FPS measurement, so
        # decoding of history frames is skipped there.
        world_size = get_dist_info()[1]
        if world_size == 1 and self.test_mode:
            return self.load_online(results)
        else:
            return self.load_offline(results)
|
loaders/pipelines/transforms.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import mmcv
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from numpy import random
|
| 6 |
+
from mmdet.datasets.builder import PIPELINES
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@PIPELINES.register_module()
class PadMultiViewImage(object):
    """Pad every image of a multi-view sample at the bottom/right edges.

    Two modes: pad to a fixed ``size``, or pad to the smallest size
    divisible by ``size_divisor``. Updates "ori_shape", "img_shape",
    "pad_shape", "pad_fixed_size" and "pad_size_divisor" in the results
    dict.

    Args:
        size (tuple, optional): Fixed padding size (h, w).
        size_divisor (int, optional): The divisor of the padded size.
        pad_val (float, optional): Padding value, 0 by default.
    """

    def __init__(self, size=None, size_divisor=None, pad_val=0):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        # Exactly one of `size` and `size_divisor` must be provided.
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, img):
        """Pad one HxWxC image and return the padded array."""
        if self.size_divisor is None:
            target_h, target_w = self.size
        else:
            # Round each spatial dim up to the next multiple of the divisor.
            div = self.size_divisor
            target_h = -(-img.shape[0] // div) * div
            target_w = -(-img.shape[1] // div) * div

        widths = ((0, target_h - img.shape[0]), (0, target_w - img.shape[1]), (0, 0))
        return np.pad(img, widths, constant_values=self.pad_val)

    def _pad_imgs(self, results):
        """Pad all views and record the shape bookkeeping keys."""
        originals = results['img']
        padded = [self._pad_img(im) for im in originals]

        results['ori_shape'] = [im.shape for im in originals]
        results['img'] = padded
        results['img_shape'] = [im.shape for im in padded]
        results['pad_shape'] = [im.shape for im in padded]
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def __call__(self, results):
        """Pad the images in ``results`` and return the updated dict.

        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Updated result dict.
        """
        self._pad_imgs(results)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(size={self.size}, '
                f'size_divisor={self.size_divisor}, '
                f'pad_val={self.pad_val})')
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize every image of a multi-view sample.

    Adds the "img_norm_cfg" key.

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32).reshape(-1)
        # Stored as the reciprocal so normalization is a multiply.
        self.std = 1 / np.array(std, dtype=np.float32).reshape(-1)
        self.to_rgb = to_rgb

    def _normalize_one(self, img):
        """Optionally flip BGR->RGB, subtract mean, scale by 1/std."""
        out = img.astype(np.float32)
        if self.to_rgb:
            out = out[..., ::-1]
        out = out - self.mean
        return out * self.std

    def __call__(self, results):
        """Normalize all entries of ``results['img']``.

        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Normalized results; 'img_norm_cfg' key is added into
                the result dict.
        """
        results['img'] = [self._normalize_one(img) for img in results['img']]
        results['img_norm_cfg'] = dict(
            mean=self.mean,
            std=self.std,
            to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})')
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@PIPELINES.register_module()
class PhotoMetricDistortionMultiViewImage:
    """Apply photometric distortion to each view sequentially; every
    transformation is applied with a probability of 0.5. The position of
    random contrast is second or second to last:
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Result dict with images distorted.
        """
        imgs = results['img']
        new_imgs = []
        for img in imgs:
            # Work in float32; the original dtype is restored at the end.
            ori_dtype = img.dtype
            img = img.astype(np.float32)

            # random brightness
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta,
                                       self.brightness_delta)
                img += delta

            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            mode = random.randint(2)
            if mode == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha

            # convert color from BGR to HSV
            img = mmcv.bgr2hsv(img)

            # random saturation
            if random.randint(2):
                img[..., 1] *= random.uniform(self.saturation_lower,
                                              self.saturation_upper)

            # random hue: shift channel 0 and wrap it back into [0, 360]
            if random.randint(2):
                img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
                img[..., 0][img[..., 0] > 360] -= 360
                img[..., 0][img[..., 0] < 0] += 360

            # convert color from HSV to BGR
            img = mmcv.hsv2bgr(img)

            # random contrast
            if mode == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha

            # randomly swap channels
            if random.randint(2):
                img = img[..., random.permutation(3)]

            # NOTE(review): values are not clipped before this cast, so
            # out-of-range floats can wrap for uint8 inputs — presumably
            # tolerated as augmentation noise; confirm.
            new_imgs.append(img.astype(ori_dtype))

        results['img'] = new_imgs
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
        repr_str += 'contrast_range='
        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
        repr_str += 'saturation_range='
        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
        repr_str += f'hue_delta={self.hue_delta})'
        return repr_str
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@PIPELINES.register_module()
class RandomTransformImage(object):
    """Image-domain augmentation: resize / crop / flip / rotate.

    One augmentation is sampled per sample and applied to every camera
    view; the matching 2D homography (``ida_mat``) is folded into each
    ``lidar2img`` so that 3D->2D projections stay consistent with the
    augmented images.
    """

    def __init__(self, ida_aug_conf=None, training=True):
        # ida_aug_conf: dict expected to provide H, W, final_dim,
        # resize_lim, bot_pct_lim, rot_lim and rand_flip
        # (see sample_augmentation).
        self.ida_aug_conf = ida_aug_conf
        self.training = training

    def __call__(self, results):
        # A single set of augmentation parameters is shared by all views.
        resize, resize_dims, crop, flip, rotate = self.sample_augmentation()

        if len(results['lidar2img']) == len(results['img']):
            # One lidar2img per image (sweeps already loaded): update each
            # matrix alongside its image.
            for i in range(len(results['img'])):
                img = Image.fromarray(np.uint8(results['img'][i]))

                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
                img, ida_mat = self.img_transform(
                    img,
                    resize=resize,
                    resize_dims=resize_dims,
                    crop=crop,
                    flip=flip,
                    rotate=rotate,
                )
                results['img'][i] = np.array(img).astype(np.uint8)
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]

        elif len(results['img']) == 6:
            # Only the 6 keyframe images exist but there may be extra
            # lidar2img entries (sweeps loaded later): transform the images
            # first, then apply the (shared) homography to every matrix.
            for i in range(len(results['img'])):
                img = Image.fromarray(np.uint8(results['img'][i]))

                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
                img, ida_mat = self.img_transform(
                    img,
                    resize=resize,
                    resize_dims=resize_dims,
                    crop=crop,
                    flip=flip,
                    rotate=rotate,
                )
                results['img'][i] = np.array(img).astype(np.uint8)

            for i in range(len(results['lidar2img'])):
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]

        else:
            raise ValueError()

        # After augmentation all shape keys describe the transformed images.
        results['ori_shape'] = [img.shape for img in results['img']]
        results['img_shape'] = [img.shape for img in results['img']]
        results['pad_shape'] = [img.shape for img in results['img']]

        return results

    def img_transform(self, img, resize, resize_dims, crop, flip, rotate):
        """Apply resize/crop/flip/rotate to a PIL image and return it with
        the matching 4x4 pixel-space homography (numpy array).

        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L48
        """
        def get_rot(h):
            # 2x2 rotation matrix for angle h (radians).
            return torch.Tensor([
                [np.cos(h), np.sin(h)],
                [-np.sin(h), np.cos(h)],
            ])

        ida_rot = torch.eye(2)
        ida_tran = torch.zeros(2)

        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)

        # post-homography transformation: accumulate the same operations on
        # 2D pixel coordinates (scale, crop shift, flip, rotation about the
        # crop center).
        ida_rot *= resize
        ida_tran -= torch.Tensor(crop[:2])

        if flip:
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            ida_rot = A.matmul(ida_rot)
            ida_tran = A.matmul(ida_tran) + b

        A = get_rot(rotate / 180 * np.pi)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b

        ida_rot = A.matmul(ida_rot)
        ida_tran = A.matmul(ida_tran) + b

        # Embed the 2D affine transform in a 4x4 matrix so it can be chained
        # with lidar2img.
        ida_mat = torch.eye(4)
        ida_mat[:2, :2] = ida_rot
        ida_mat[:2, 2] = ida_tran

        return img, ida_mat.numpy()

    def sample_augmentation(self):
        """Sample (resize, resize_dims, crop, flip, rotate) parameters.

        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L247
        """
        H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']
        fH, fW = self.ida_aug_conf['final_dim']

        if self.training:
            # Random scale, crop position, flip and rotation.
            resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*self.ida_aug_conf['rot_lim'])
        else:
            # Deterministic: smallest resize covering final_dim, centered
            # horizontal crop, no flip or rotation.
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate = 0

        return resize, resize_dims, crop, flip, rotate
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@PIPELINES.register_module()
class GlobalRotScaleTransImage(object):
    """Global rotation/scaling augmentation for camera-based 3D detection.

    Rotates (about z) and scales the ground-truth boxes, and folds the
    inverse transform into every ``lidar2img`` matrix so the augmented
    boxes still project onto the original image pixels.
    """

    def __init__(self,
                 rot_range=[-0.3925, 0.3925],
                 scale_ratio_range=[0.95, 1.05],
                 translation_std=[0, 0, 0]):
        # rot_range: range of the random z-rotation angle.
        self.rot_range = rot_range
        self.scale_ratio_range = scale_ratio_range
        # translation_std is accepted but currently unused (see TODO below).
        self.translation_std = translation_std

    def __call__(self, results):
        # random rotate
        rot_angle = np.random.uniform(*self.rot_range)
        self.rotate_z(results, rot_angle)
        results["gt_bboxes_3d"].rotate(np.array(rot_angle))

        # random scale
        scale_ratio = np.random.uniform(*self.scale_ratio_range)
        self.scale_xyz(results, scale_ratio)
        results["gt_bboxes_3d"].scale(scale_ratio)

        # TODO: support translation

        return results

    def rotate_z(self, results, rot_angle):
        """Fold the inverse z-rotation into every lidar2img matrix."""
        rot_cos = torch.cos(torch.tensor(rot_angle))
        rot_sin = torch.sin(torch.tensor(rot_angle))

        # Homogeneous 4x4 rotation about the z axis.
        rot_mat = torch.tensor([
            [rot_cos, -rot_sin, 0, 0],
            [rot_sin, rot_cos, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        # Points are rotated by rot_mat, so the projection composes with
        # its inverse to cancel the augmentation.
        rot_mat_inv = torch.inverse(rot_mat)

        for view in range(len(results['lidar2img'])):
            results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ rot_mat_inv).numpy()

    def scale_xyz(self, results, scale_ratio):
        """Fold the inverse uniform xyz-scaling into every lidar2img matrix."""
        scale_mat = torch.tensor([
            [scale_ratio, 0, 0, 0],
            [0, scale_ratio, 0, 0],
            [0, 0, scale_ratio, 0],
            [0, 0, 0, 1],
        ])
        scale_mat_inv = torch.inverse(scale_mat)

        for view in range(len(results['lidar2img'])):
            results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ scale_mat_inv).numpy()
|
models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Importing these submodules triggers their mmdet/mmdet3d registry
# decorators so that the models can be built by name from configs.
#
# NOTE: the original code imported ``__all__`` from ``.backbones`` and then
# from ``.bbox``, shadowing the name twice before reassigning it below.
# Importing the submodules themselves keeps the registration side effect
# without the confusing name clashes.
from . import backbones  # noqa: F401
from . import bbox  # noqa: F401
from .sparsebev import SparseBEV
from .sparsebev_head import SparseBEVHead
from .sparsebev_transformer import SparseBEVTransformer

__all__ = [
    'SparseBEV', 'SparseBEVHead', 'SparseBEVTransformer'
]
|
models/backbones/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .vovnet import VoVNet
|
| 2 |
+
|
| 3 |
+
__all__ = ['VoVNet']
|
models/backbones/vovnet.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import warnings
|
| 5 |
+
import torch.utils.checkpoint as cp
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from mmcv.runner import BaseModule
|
| 8 |
+
from mmdet.models.builder import BACKBONES
|
| 9 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# VoVNet architecture specifications. Each spec describes:
#   stem:            output channels of the three stem convs
#   stage_conv_ch:   per-stage channels of the OSA 3x3 convs (stages 2..5)
#   stage_out_ch:    per-stage output channels after the 1x1 concat conv
#   layer_per_block: number of 3x3 convs inside one OSA module
#   block_per_stage: number of OSA modules per stage
#   eSE:             enable effective Squeeze-Excitation attention
#   dw:              use depthwise-separable 3x3 convs
VoVNet19_slim_dw_eSE = {
    'stem': [64, 64, 64],
    'stage_conv_ch': [64, 80, 96, 112],
    'stage_out_ch': [112, 256, 384, 512],
    "layer_per_block": 3,
    "block_per_stage": [1, 1, 1, 1],
    "eSE": True,
    "dw": True
}

VoVNet19_dw_eSE = {
    'stem': [64, 64, 64],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 3,
    "block_per_stage": [1, 1, 1, 1],
    "eSE": True,
    "dw": True
}

VoVNet19_slim_eSE = {
    'stem': [64, 64, 128],
    'stage_conv_ch': [64, 80, 96, 112],
    'stage_out_ch': [112, 256, 384, 512],
    'layer_per_block': 3,
    'block_per_stage': [1, 1, 1, 1],
    'eSE': True,
    "dw": False
}

VoVNet19_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 3,
    "block_per_stage": [1, 1, 1, 1],
    "eSE": True,
    "dw": False
}

VoVNet39_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 1, 2, 2],
    "eSE": True,
    "dw": False
}

VoVNet57_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 1, 4, 3],
    "eSE": True,
    "dw": False
}

VoVNet99_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 3, 9, 3],
    "eSE": True,
    "dw": False
}

# Registry mapping config-friendly names (e.g. "V-99-eSE") to the specs above;
# consumed by VoVNet.__init__ via its spec_name argument.
_STAGE_SPECS = {
    "V-19-slim-dw-eSE": VoVNet19_slim_dw_eSE,
    "V-19-dw-eSE": VoVNet19_dw_eSE,
    "V-19-slim-eSE": VoVNet19_slim_eSE,
    "V-19-eSE": VoVNet19_eSE,
    "V-39-eSE": VoVNet39_eSE,
    "V-57-eSE": VoVNet57_eSE,
    "V-99-eSE": VoVNet99_eSE,
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):
    """Depthwise-separable 3x3 convolution.

    Returns an ordered list of (name, module) pairs — depthwise 3x3 conv,
    pointwise 1x1 conv, BatchNorm, ReLU — suitable for nn.Sequential via
    OrderedDict. Norm/activation follow only the pointwise conv.
    """
    prefix = '{}_{}'.format(module_name, postfix)

    depthwise = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=out_channels,
        bias=False
    )
    pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)

    return [
        (prefix + '/dw_conv3x3', depthwise),
        (prefix + '/pw_conv1x1', pointwise),
        (prefix + '/pw_norm', nn.BatchNorm2d(out_channels)),
        (prefix + '/pw_relu', nn.ReLU(inplace=True)),
    ]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):
    """3x3 convolution with padding, followed by BatchNorm and ReLU.

    Returns an ordered list of (name, module) pairs suitable for
    nn.Sequential via OrderedDict.
    """
    base = f"{module_name}_{postfix}"
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,
    )
    return [
        (base + "/conv", conv),
        (base + "/norm", nn.BatchNorm2d(out_channels)),
        (base + "/relu", nn.ReLU(inplace=True)),
    ]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=1, padding=0):
    """1x1 convolution with padding, followed by BatchNorm and ReLU.

    Returns an ordered list of (name, module) pairs suitable for
    nn.Sequential via OrderedDict.
    """
    base = f"{module_name}_{postfix}"
    projection = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,
    )
    return [
        (base + "/conv", projection),
        (base + "/norm", nn.BatchNorm2d(out_channels)),
        (base + "/relu", nn.ReLU(inplace=True)),
    ]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class Hsigmoid(nn.Module):
    """Hard sigmoid activation: relu6(x + 3) / 6, a piecewise-linear
    approximation of the sigmoid that is cheap to compute."""

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        # relu6 may operate in place on the shifted tensor when True
        self.inplace = inplace

    def forward(self, x):
        shifted = F.relu6(x + 3.0, inplace=self.inplace)
        return shifted / 6.0
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class eSEModule(nn.Module):
    """Effective Squeeze-Excitation channel attention.

    Global-average-pools the input to a per-channel descriptor, passes it
    through a single 1x1 conv and a hard sigmoid, and rescales the input
    channels by the resulting gate. The ``reduction`` argument is kept for
    signature compatibility; the 1x1 conv does not reduce channels.
    """

    def __init__(self, channel, reduction=4):
        super(eSEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
        self.hsigmoid = Hsigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.fc(gate)
        gate = self.hsigmoid(gate)
        # broadcast the (N, C, 1, 1) gate over spatial dims
        return x * gate
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class _OSA_module(nn.Module):
    """One-Shot Aggregation (OSA) block of VoVNet.

    Runs ``layer_per_block`` sequential 3x3 convs, concatenates the block
    input together with every intermediate output along the channel dim,
    projects the concatenation to ``concat_ch`` with a 1x1 conv, then
    applies eSE channel attention. An optional identity (residual) add is
    applied when ``identity`` is True.

    Note: the ``SE`` argument is accepted but never read here — the eSE
    gate is always created.
    """

    def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False, with_cp=False):
        super(_OSA_module, self).__init__()
        # with_cp: trade compute for memory via torch.utils.checkpoint
        self.with_cp = with_cp

        self.identity = identity
        self.depthwise = depthwise
        self.isReduced = False
        self.layers = nn.ModuleList()
        in_channel = in_ch

        # Depthwise convs keep their channel count, so the input must first
        # be projected to stage_ch when the widths differ.
        if self.depthwise and in_channel != stage_ch:
            self.isReduced = True
            self.conv_reduction = nn.Sequential(
                OrderedDict(conv1x1(in_channel, stage_ch, "{}_reduction".format(module_name), "0"))
            )

        for i in range(layer_per_block):
            if self.depthwise:
                self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))
            else:
                self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))
            # only the first conv sees in_ch channels; the rest see stage_ch
            in_channel = stage_ch

        # feature aggregation: input + every intermediate feature map
        in_channel = in_ch + layer_per_block * stage_ch
        self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, "concat")))

        self.ese = eSEModule(concat_ch)

    def _forward(self, x):
        # Keep the raw input for the optional residual connection.
        identity_feat = x

        output = []
        output.append(x)

        if self.depthwise and self.isReduced:
            x = self.conv_reduction(x)

        # Collect the output of every conv for one-shot aggregation.
        for layer in self.layers:
            x = layer(x)
            output.append(x)

        x = torch.cat(output, dim=1)
        xt = self.concat(x)

        xt = self.ese(xt)

        if self.identity:
            xt = xt + identity_feat

        return xt

    def forward(self, x):
        # Checkpointing only helps while training with gradients enabled.
        if self.with_cp and self.training and x.requires_grad:
            return cp.checkpoint(self._forward, x)
        else:
            return self._forward(x)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class _OSA_stage(nn.Sequential):
    """One VoVNet stage: an optional stride-2 max-pool downsample followed
    by ``block_per_stage`` _OSA_module blocks. Every block after the first
    keeps the channel count (concat_ch -> concat_ch) and uses an identity
    residual connection.
    """

    def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False, with_cp=False):
        super(_OSA_stage, self).__init__()
        # Stage 2 runs at stem resolution; later stages downsample first.
        if not stage_num == 2:
            self.add_module("Pooling", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))

        if block_per_stage != 1:
            SE = False

        module_name = f"OSA{stage_num}_1"
        self.add_module(
            module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise, with_cp=with_cp)
        )

        for i in range(block_per_stage - 1):
            # NOTE(review): this looks intended to enable SE only on the last
            # block, but SE was already forced False above for multi-block
            # stages and is never set True again; _OSA_module also ignores
            # its SE argument entirely — confirm intent upstream.
            if i != block_per_stage - 2:  # last block
                SE = False
            module_name = f"OSA{stage_num}_{i + 2}"
            self.add_module(
                module_name,
                _OSA_module(
                    concat_ch,
                    stage_ch,
                    concat_ch,
                    layer_per_block,
                    module_name,
                    SE,
                    identity=True,
                    depthwise=depthwise,
                    with_cp=with_cp
                ),
            )
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@BACKBONES.register_module()
class VoVNet(BaseModule):
    # VoVNet backbone with eSE attention; emits a dict of multi-scale
    # feature maps selected by `out_features`.

    def __init__(self, spec_name, input_ch=3, out_features=None, frozen_stages=-1, norm_eval=True, with_cp=False, pretrained=None, init_cfg=None):
        """
        Args:
            spec_name (str): key into _STAGE_SPECS, e.g. "V-99-eSE".
            input_ch (int): the number of input channels.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "stage2" ...
            frozen_stages (int): the stem plus stages 2..(frozen_stages+1)
                are frozen when >= 0.
            norm_eval (bool): keep BatchNorm layers in eval mode in train().
            with_cp (bool): use gradient checkpointing inside OSA modules.
            pretrained (str, optional): deprecated; use init_cfg instead.
            init_cfg (dict, optional): mmcv-style weight init config.
        """
        super(VoVNet, self).__init__(init_cfg)
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        stage_specs = _STAGE_SPECS[spec_name]

        stem_ch = stage_specs["stem"]
        config_stage_ch = stage_specs["stage_conv_ch"]
        config_concat_ch = stage_specs["stage_out_ch"]
        block_per_stage = stage_specs["block_per_stage"]
        layer_per_block = stage_specs["layer_per_block"]
        SE = stage_specs["eSE"]
        depthwise = stage_specs["dw"]

        self._out_features = out_features

        # Stem module: three 3x3 convs; convs 1 and 3 have stride 2,
        # giving an overall output stride of 4.
        conv_type = dw_conv3x3 if depthwise else conv3x3
        stem = conv3x3(input_ch, stem_ch[0], "stem", "1", 2)
        stem += conv_type(stem_ch[0], stem_ch[1], "stem", "2", 1)
        stem += conv_type(stem_ch[1], stem_ch[2], "stem", "3", 2)
        self.add_module("stem", nn.Sequential((OrderedDict(stem))))
        current_stirde = 4  # [sic] running output stride
        self._out_feature_strides = {"stem": current_stirde, "stage2": current_stirde}
        self._out_feature_channels = {"stem": stem_ch[2]}

        stem_out_ch = [stem_ch[2]]
        # Each stage consumes the previous stage's concat output.
        in_ch_list = stem_out_ch + config_concat_ch[:-1]

        # OSA stages
        self.stage_names = []
        for i in range(4):  # num_stages
            name = "stage%d" % (i + 2)  # stage 2 ... stage 5
            self.stage_names.append(name)
            self.add_module(
                name,
                _OSA_stage(
                    in_ch_list[i],
                    config_stage_ch[i],
                    config_concat_ch[i],
                    block_per_stage[i],
                    layer_per_block,
                    i + 2,
                    SE,
                    depthwise,
                    with_cp=with_cp
                ),
            )

            self._out_feature_channels[name] = config_concat_ch[i]
            # stage2 has no pooling, so only stages 3..5 double the stride.
            if not i == 0:
                self._out_feature_strides[name] = current_stirde = int(current_stirde * 2)

        # initialize weights
        # self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming-init every conv. Currently unused (see commented call above);
        # weight init comes from init_cfg / BaseModule instead.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Run the stem and all stages; return {name: feature} for every
        name listed in self._out_features."""
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name in self.stage_names:
            x = getattr(self, name)(x)
            if name in self._out_features:
                outputs[name] = x

        return outputs

    def _freeze_stages(self):
        # Put frozen submodules in eval mode and stop their gradients.
        if self.frozen_stages >= 0:
            m = getattr(self, 'stem')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        # frozen_stages == k freezes stage2 .. stage(k+1).
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'stage{i+1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(VoVNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
|
models/bbox/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .assigners import __all__
|
| 2 |
+
from .coders import __all__
|
| 3 |
+
from .match_costs import __all__
|
models/bbox/assigners/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .hungarian_assigner_3d import HungarianAssigner3D
|
| 2 |
+
|
| 3 |
+
__all__ = ['HungarianAssigner3D']
|
models/bbox/assigners/hungarian_assigner_3d.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
|
| 4 |
+
from mmdet.core.bbox.assigners import AssignResult
|
| 5 |
+
from mmdet.core.bbox.assigners import BaseAssigner
|
| 6 |
+
from mmdet.core.bbox.match_costs import build_match_cost
|
| 7 |
+
from ..utils import normalize_bbox
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from scipy.optimize import linear_sum_assignment
|
| 11 |
+
except ImportError:
|
| 12 |
+
linear_sum_assignment = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@BBOX_ASSIGNERS.register_module()
class HungarianAssigner3D(BaseAssigner):
    """One-to-one assigner for 3D detection.

    Matches predictions to ground-truth boxes by minimizing a combined
    classification + L1 regression cost with the Hungarian algorithm
    (scipy.optimize.linear_sum_assignment).

    Note: ``iou_cost`` is built in __init__ but not used by assign() below.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', weight=0.0),
                 pc_range=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)
        self.pc_range = pc_range

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               gt_bboxes_ignore=None,
               code_weights=None,
               with_velo=False):
        """Assign each prediction to at most one ground-truth box.

        Args:
            bbox_pred (Tensor): [num_query, code_size] predicted boxes in
                the normalized regression format (see normalize_bbox).
            cls_pred (Tensor): [num_query, num_classes] classification
                scores/logits (interpretation depends on cls_cost).
            gt_bboxes (Tensor): [num_gt, 7 or 9] ground-truth boxes.
            gt_labels (Tensor): [num_gt] ground-truth class indices.
            gt_bboxes_ignore: unsupported, must be None.
            code_weights (Tensor, optional): per-dimension weights applied
                to both predictions and GTs before the L1 cost.
            with_velo (bool): include the velocity dims in the reg cost;
                otherwise only the first 8 dims are compared.

        Returns:
            AssignResult: 0 means background, i+1 means matched to
            gt_bboxes[i].
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        # classification and bboxcost.
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        normalized_gt_bboxes = normalize_bbox(gt_bboxes)

        if code_weights is not None:
            bbox_pred = bbox_pred * code_weights
            normalized_gt_bboxes = normalized_gt_bboxes * code_weights

        if with_velo:
            reg_cost = self.reg_cost(bbox_pred, normalized_gt_bboxes)
        else:
            # drop velocity dims (indices 8:) from the matching cost
            reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])

        # weighted sum of above two costs
        cost = cls_cost + reg_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        # guard against NaN/inf entries that would break the LAP solver
        cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0)

        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
|
models/bbox/coders/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .nms_free_coder import NMSFreeCoder
|
| 2 |
+
|
| 3 |
+
__all__ = ['NMSFreeCoder']
|
models/bbox/coders/nms_free_coder.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from mmdet.core.bbox import BaseBBoxCoder
|
| 4 |
+
from mmdet.core.bbox.builder import BBOX_CODERS
|
| 5 |
+
from ..utils import denormalize_bbox
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@BBOX_CODERS.register_module()
class NMSFreeCoder(BaseBBoxCoder):
    """Bbox coder for NMS-free detector.
    Selects the top-k (class, query) scores, denormalizes the matching box
    predictions, and filters by center range / score threshold — no NMS.

    Args:
        pc_range (list[float]): Range of point cloud.
        voxel_size (list[float], optional): stored but unused here.
        post_center_range (list[float]): Limit of the center.
            Default: None.
        max_num (int): Max number to be kept. Default: 100.
        score_threshold (float): Threshold to filter boxes based on score.
            Default: None.
        num_classes (int): number of classes. Default: 10.
    """
    def __init__(self,
                 pc_range,
                 voxel_size=None,
                 post_center_range=None,
                 max_num=100,
                 score_threshold=None,
                 num_classes=10):
        self.pc_range = pc_range
        self.voxel_size = voxel_size
        self.post_center_range = post_center_range
        self.max_num = max_num
        self.score_threshold = score_threshold
        self.num_classes = num_classes

    def encode(self):
        # Encoding is not needed for this coder; present to satisfy the
        # BaseBBoxCoder interface.
        pass

    def decode_single(self, cls_scores, bbox_preds):
        """Decode bboxes for one sample.
        Args:
            cls_scores (Tensor): Outputs from the classification head, \
                shape [num_query, cls_out_channels]. Note \
                cls_out_channels should includes background.
            bbox_preds (Tensor): Outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [num_query, 9].
        Returns:
            dict: decoded 'bboxes', 'scores' and 'labels' tensors.
        """
        max_num = self.max_num

        cls_scores = cls_scores.sigmoid()
        # top-k over the flattened (query, class) score matrix
        scores, indexs = cls_scores.view(-1).topk(max_num)
        labels = indexs % self.num_classes
        bbox_index = torch.div(indexs, self.num_classes, rounding_mode='trunc')
        bbox_preds = bbox_preds[bbox_index]

        final_box_preds = denormalize_bbox(bbox_preds)
        final_scores = scores
        final_preds = labels

        # use score threshold
        if self.score_threshold is not None:
            thresh_mask = final_scores > self.score_threshold

        if self.post_center_range is not None:
            # keep boxes whose centers fall inside the configured range
            limit = torch.tensor(self.post_center_range, device=scores.device)
            mask = (final_box_preds[..., :3] >= limit[:3]).all(1)
            mask &= (final_box_preds[..., :3] <= limit[3:]).all(1)

            # NOTE(review): truthiness test — a score_threshold of 0.0 is
            # treated as "no threshold" here; confirm that is intended.
            if self.score_threshold:
                mask &= thresh_mask

            boxes3d = final_box_preds[mask]
            scores = final_scores[mask]
            labels = final_preds[mask]
            predictions_dict = {
                'bboxes': boxes3d,
                'scores': scores,
                'labels': labels
            }

        else:
            raise NotImplementedError(
                'Need to reorganize output as a batch, only '
                'support post_center_range is not None for now!'
            )

        return predictions_dict

    def decode(self, preds_dicts):
        """Decode bboxes for a whole batch (last decoder layer only).
        Args:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should includes background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [nb_dec, bs, num_query, 9].
        Returns:
            list[dict]: Decoded boxes, one dict per batch element.
        """
        all_cls_scores = preds_dicts['all_cls_scores'][-1]
        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]

        batch_size = all_cls_scores.size()[0]
        predictions_list = []
        for i in range(batch_size):
            predictions_list.append(self.decode_single(all_cls_scores[i], all_bbox_preds[i]))

        return predictions_list
|
models/bbox/match_costs/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .match_cost import BBox3DL1Cost
|
| 2 |
+
|
| 3 |
+
__all__ = ['BBox3DL1Cost']
|
models/bbox/match_costs/match_cost.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from mmdet.core.bbox.match_costs.builder import MATCH_COST
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@MATCH_COST.register_module()
class BBox3DL1Cost(object):
    """Pairwise L1 matching cost between 3D box predictions and GT boxes.

    Args:
        weight (int | float, optional): scale factor applied to the cost.
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """Compute the weighted pairwise L1 distance.

        Args:
            bbox_pred (Tensor): predicted boxes, shape [num_query, D].
            gt_bboxes (Tensor): ground-truth boxes, shape [num_gt, D].
        Returns:
            torch.Tensor: [num_query, num_gt] cost matrix.
        """
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return pairwise_l1 * self.weight
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@MATCH_COST.register_module()
class BBoxBEVL1Cost(object):
    """Pairwise L1 cost on BEV (x, y) box centers, normalized to [0, 1]
    using the point-cloud range."""

    def __init__(self, weight, pc_range):
        self.weight = weight
        self.pc_range = pc_range

    def __call__(self, bboxes, gt_bboxes):
        lower = bboxes.new(self.pc_range[0:2])
        extent = bboxes.new(self.pc_range[3:5]) - bboxes.new(self.pc_range[0:2])
        # normalize the box center to [0, 1]
        pred_xy = (bboxes[:, :2] - lower) / extent
        gt_xy = (gt_bboxes[:, :2] - lower) / extent
        distance = torch.cdist(pred_xy, gt_xy, p=1)
        return distance * self.weight
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@MATCH_COST.register_module()
class IoU3DCost(object):
    """Negated-IoU matching cost: higher IoU yields lower cost."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, iou):
        return (-iou) * self.weight
|
models/bbox/utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def normalize_bbox(bboxes):
    """Convert (cx, cy, cz, w, l, h, rot[, vx, vy]) boxes to the network's
    regression format: (cx, cy, log w, log l, cz, log h, sin rot, cos rot
    [, vx, vy])."""
    cx = bboxes[..., 0:1]
    cy = bboxes[..., 1:2]
    cz = bboxes[..., 2:3]
    log_dims = bboxes[..., 3:6].log()
    log_w = log_dims[..., 0:1]
    log_l = log_dims[..., 1:2]
    log_h = log_dims[..., 2:3]
    yaw = bboxes[..., 6:7]

    parts = [cx, cy, log_w, log_l, cz, log_h, yaw.sin(), yaw.cos()]
    if bboxes.size(-1) > 7:
        # carry velocity through unchanged
        parts.append(bboxes[..., 7:8])
        parts.append(bboxes[..., 8:9])
    return torch.cat(parts, dim=-1)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def denormalize_bbox(normalized_bboxes):
    """Inverse of normalize_bbox: recover (cx, cy, cz, w, l, h, rot[, vx, vy])
    from (cx, cy, log w, log l, cz, log h, sin rot, cos rot [, vx, vy])."""
    yaw = torch.atan2(normalized_bboxes[..., 6:7], normalized_bboxes[..., 7:8])

    center = torch.cat([
        normalized_bboxes[..., 0:1],   # cx
        normalized_bboxes[..., 1:2],   # cy
        normalized_bboxes[..., 4:5],   # cz
    ], dim=-1)
    sizes = torch.cat([
        normalized_bboxes[..., 2:3].exp(),  # w
        normalized_bboxes[..., 3:4].exp(),  # l
        normalized_bboxes[..., 5:6].exp(),  # h
    ], dim=-1)

    if normalized_bboxes.size(-1) > 8:
        velocity = torch.cat([normalized_bboxes[..., 8:9], normalized_bboxes[..., 9:10]], dim=-1)
        return torch.cat([center, sizes, yaw, velocity], dim=-1)
    return torch.cat([center, sizes, yaw], dim=-1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def encode_bbox(bboxes, pc_range=None):
    """Encode boxes (x, y, z, w, l, h, rot[, vx, vy]) for regression:
    log-dims, sin/cos yaw, and — when pc_range is given — centers
    min-max normalized to [0, 1] over the point-cloud range."""
    center = bboxes[..., 0:3].clone()
    log_dims = bboxes[..., 3:6].log()
    yaw = bboxes[..., 6:7]

    if pc_range is not None:
        for axis in range(3):
            lo, hi = pc_range[axis], pc_range[axis + 3]
            center[..., axis] = (center[..., axis] - lo) / (hi - lo)

    parts = [center, log_dims, yaw.sin(), yaw.cos()]
    if bboxes.shape[-1] > 7:
        parts.append(bboxes[..., 7:9].clone())
    return torch.cat(parts, dim=-1)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def decode_bbox(bboxes, pc_range=None):
    """Decode regression outputs back to metric boxes; inverse of ``encode_bbox``.

    Log-sizes are exponentiated, yaw is recovered from (sin, cos) via atan2,
    and — when ``pc_range`` is given — the [0, 1]-normalized center is mapped
    back to metric coordinates.
    """
    center = bboxes[..., 0:3].clone()
    sizes = bboxes[..., 3:6].exp()
    yaw = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])

    if pc_range is not None:
        # Rescale each center axis from [0, 1] back to the metric range.
        for axis in range(3):
            low, high = pc_range[axis], pc_range[axis + 3]
            center[..., axis] = center[..., axis] * (high - low) + low

    fields = [center, sizes, yaw]
    if bboxes.shape[-1] > 8:
        # Velocity (vx, vy) is passed through unchanged.
        fields.append(bboxes[..., 8:10].clone())
    return torch.cat(fields, dim=-1)
|
models/checkpoint.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import warnings
|
| 5 |
+
import weakref
|
| 6 |
+
from typing import Any, Iterable, List, Tuple
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"checkpoint", "checkpoint_sequential", "CheckpointFunction",
|
| 10 |
+
"check_backward_validity", "detach_variable", "get_device_states",
|
| 11 |
+
"set_device_states",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    """Return ``inputs`` with every tensor detached from the autograd graph.

    Each detached tensor keeps the original's ``requires_grad`` flag so the
    recomputed forward can build gradients for it; non-tensor entries are
    carried through unchanged. Only tuples are accepted.
    """
    if not isinstance(inputs, tuple):
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)

    detached = []
    for item in inputs:
        if isinstance(item, torch.Tensor):
            twin = item.detach()
            # Preserve the grad flag so backward can differentiate w.r.t. it.
            twin.requires_grad = item.requires_grad
            detached.append(twin)
        else:
            detached.append(item)
    return tuple(detached)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn when no tensor in ``inputs`` can receive a gradient."""
    tensor_inputs = (inp for inp in inputs if isinstance(inp, torch.Tensor))
    if not any(t.requires_grad for t in tensor_inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# We can't know if the run_fn will internally move some args to different devices,
|
| 37 |
+
# which would require logic to preserve rng states for those devices as well.
|
| 38 |
+
# We could paranoically stash and restore ALL the rng states for all visible devices,
|
| 39 |
+
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
|
| 40 |
+
# the device of all Tensor args.
|
| 41 |
+
#
|
| 42 |
+
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
|
| 43 |
+
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """Capture the CUDA RNG state of every GPU that holds a tensor in ``args``.

    Returns a pair ``(devices, states)`` suitable for ``set_device_states``.
    CPU tensors and non-tensor arguments are ignored (the conditionals
    short-circuit, so they never raise).
    """
    # Deduplicate device ids across all CUDA tensor arguments.
    seen_devices = {a.get_device() for a in args
                    if isinstance(a, torch.Tensor) and a.is_cuda}
    devices = list(seen_devices)

    states = []
    for dev in devices:
        with torch.cuda.device(dev):
            states.append(torch.cuda.get_rng_state())

    return devices, states
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def set_device_states(devices, states) -> None:
    """Restore per-device CUDA RNG states captured by ``get_device_states``."""
    for dev, rng_state in zip(devices, states):
        with torch.cuda.device(dev):
            torch.cuda.set_rng_state(rng_state)
|
| 61 |
+
|
| 62 |
+
def _get_autocast_kwargs():
|
| 63 |
+
gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
| 64 |
+
"dtype": torch.get_autocast_gpu_dtype(),
|
| 65 |
+
"cache_enabled": torch.is_autocast_cache_enabled()}
|
| 66 |
+
|
| 67 |
+
cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
|
| 68 |
+
"dtype": torch.get_autocast_cpu_dtype(),
|
| 69 |
+
"cache_enabled": torch.is_autocast_cache_enabled()}
|
| 70 |
+
|
| 71 |
+
return gpu_autocast_kwargs, cpu_autocast_kwargs
|
| 72 |
+
|
| 73 |
+
class CheckpointFunction(torch.autograd.Function):
    """Reentrant activation-checkpointing autograd Function.

    Forward runs ``run_function`` under ``torch.no_grad`` (so intermediate
    activations are not stored) while stashing the inputs plus the CPU/CUDA
    RNG and autocast state. Backward replays the forward under
    ``torch.enable_grad`` with that state restored, then backpropagates
    through the recomputed outputs.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        # Warn early if no tensor input can receive a gradient.
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            # NOTE(review): torch.cuda._initialized is a private torch attribute.
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # Tensors go through save_for_backward so autograd can manage them.
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        # ``args`` are the incoming gradients, one per forward output.
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # Detach so the replayed forward builds a fresh graph rooted here.
            detached_inputs = detach_variable(tuple(inputs))
            # Replay the forward with grad enabled and the saved autocast state.
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # Gradients for non-tensor inputs are None; the two leading Nones
        # correspond to the (run_function, preserve_rng_state) arguments.
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model.

    Trades compute for memory: the forward pass of :attr:`function` runs
    under :func:`torch.no_grad`, so intermediate activations are not stored;
    they are recomputed during the backward pass instead.

    .. warning::
        If :attr:`function` behaves differently on replay (e.g. due to a
        global variable), the checkpointed result won't be equivalent and
        this cannot be detected.

    .. warning::
        With ``use_reentrant=True`` (the default): the segment must not
        produce tensors detached from the graph, at least one input and one
        output need ``requires_grad=True`` for gradients to flow, keyword
        arguments to :attr:`function` are unsupported, and only
        :func:`torch.autograd.backward` (without its ``inputs`` argument)
        works — not :func:`torch.autograd.grad`. ``use_reentrant=False``
        lifts these restrictions.

    Args:
        function: callable run in the forward pass; receives ``*args`` (e.g.
            for an LSTM passed ``(activation, hidden)`` it must treat the
            first input as ``activation`` and the second as ``hidden``).
        preserve_rng_state (bool, optional): stash and restore the RNG state
            around each checkpoint. Default: ``True``
        use_reentrant (bool, optional): use the implementation that requires
            re-entrant autograd. Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # preserve_rng_state is pulled out of **kwargs so *args and keyword
    # options can be mixed (python 2.7-compliant hack).
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs and use_reentrant:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    if use_reentrant:
        return CheckpointFunction.apply(function, preserve, *args)
    # The non-reentrant path additionally supports kwargs and autograd.grad().
    return _checkpoint_without_reentrant(
        function,
        preserve,
        *args,
        **kwargs,
    )
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Divides a sequential model into :attr:`segments` chunks and checkpoints
    each one. All segments except the last run in :func:`torch.no_grad`
    manner; each checkpointed segment's input is saved so it can be re-run
    in the backward pass. See :func:`~torch.utils.checkpoint.checkpoint`
    for how checkpointing works and for the caveats (``torch.autograd.grad``
    unsupported with ``use_reentrant=True``; at least one input needs
    ``requires_grad=True`` for input gradients; single-Tensor input/output
    per segment since PyTorch 1.4).

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state (bool, optional): Omit stashing and restoring the
            RNG state during each checkpoint. Default: ``True``
        use_reentrant (bool, optional): Use the checkpointing implementation
            that requires re-entrant autograd. Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Keyword-only parameter, extracted in a python 2.7-compliant way.
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    def _segment_runner(first, last, fns):
        # Build a closure running fns[first..last] (inclusive) in order.
        def forward(x):
            for idx in range(first, last + 1):
                x = fns[idx](x)
            return x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            _segment_runner(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    # Run the final (non-checkpointed) segment directly.
    return _segment_runner(end + 1, len(functions) - 1, functions)(input)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd.

    Runs ``function`` once normally, but registers saved-tensors hooks so
    autograd stores lightweight ``Holder`` placeholders instead of the
    activations. On first unpack, the forward is replayed (with the stashed
    RNG/autocast state) to repopulate the activations on demand.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.

    Returns:
        Output of running ``function`` on ``*args`` / ``**kwargs``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        # NOTE(review): torch.cuda._initialized is a private torch attribute.
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # First unpack triggers a full replay of the forward to refill `storage`;
        # subsequent unpacks just read from it.
        unpack_counter = 0
        if len(storage) == 0:
            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    # Output is discarded: the replay's only purpose is the
                    # inner_pack side effect of filling `storage`.
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
|
models/csrc/__init__.py
ADDED
|
File without changes
|
models/csrc/msmv_sampling/msmv_sampling.cpp
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "msmv_sampling.h"

// Upper bound on sampling points per query; the forward wrapper asserts
// `num_point <= MAX_POINT` before launching the kernels.
#define MAX_POINT 32

// Forward launcher: multi-scale multi-view deformable sampling over a
// 4-level feature pyramid (C2..C5).
// Presumably defined in msmv_sampling_forward.cu — confirm against build.
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col
);

// Forward launcher: same as above for a 5-level pyramid (C2..C6).
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col
);

// Backward launcher for the 4-level forward: scatters `grad_col` into
// per-level feature gradients plus sampling-location / attention-weight
// gradients.
void ms_deformable_col2im_cuda_c2345(
    const float* grad_col,
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* grad_value_c2,
    float* grad_value_c3,
    float* grad_value_c4,
    float* grad_value_c5,
    float* grad_sampling_loc,
    float* grad_attn_weight
);

// Backward launcher for the 5-level (C2..C6) forward.
void ms_deformable_col2im_cuda_c23456(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const float *feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_value_c6,
    float *grad_sampling_loc,
    float *grad_attn_weight
);
|
| 97 |
+
|
| 98 |
+
// Forward entry point for multi-scale multi-view deformable sampling over
// four feature levels (C2-C5). Each feature map is [B, N, H, W, C];
// sampling_loc carries normalized (w, h, view) triplets in [0, 1] and
// attn_weight one scalar per level per point. Returns [B, Q, C, P].
at::Tensor ms_deform_attn_cuda_c2345_forward(
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 4]
) {
    // The kernel indexes raw pointers, so every input must be contiguous
    // and resident on the GPU.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int bs = feat_c2.size(0);
    const int views = feat_c2.size(1);
    const int ch = feat_c2.size(4);
    const int queries = sampling_loc.size(1);
    const int points = sampling_loc.size(2);
    // MAX_POINT bounds per-thread state inside the CUDA kernel.
    AT_ASSERTM(points <= MAX_POINT, "num_point exceed limits");

    const int h2 = feat_c2.size(2), w2 = feat_c2.size(3);
    const int h3 = feat_c3.size(2), w3 = feat_c3.size(3);
    const int h4 = feat_c4.size(2), w4 = feat_c4.size(3);
    const int h5 = feat_c5.size(2), w5 = feat_c5.size(3);

    auto out = at::zeros({ bs, queries, ch, points }, feat_c2.options());
    ms_deformable_im2col_cuda_c2345(
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        h2, w2, h3, w3, h4, w4, h5, w5,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        bs, ch, views, queries, points,
        out.data_ptr<float>()
    );

    return out;
}
|
| 151 |
+
|
| 152 |
+
// Forward entry point for the 5-level (C2..C6) variant of multi-scale
// multi-view deformable sampling. Identical structure to the C2-C5
// version, with one extra feature level; attn_weight therefore carries
// FIVE weights per sampling point (the kernel reads stride-5 weight
// groups), not four.
at::Tensor ms_deform_attn_cuda_c23456_forward(
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& feat_c6, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 5] -- one weight per level C2..C6
) {
    // Raw-pointer kernel: inputs must be contiguous CUDA tensors.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch_size = feat_c2.size(0);
    const int num_views = feat_c2.size(1);
    const int channels = feat_c2.size(4);
    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(2);
    // MAX_POINT is a compile-time kernel limit.
    AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");

    // Per-level spatial sizes; each level may have a different resolution.
    const int h_c2 = feat_c2.size(2);
    const int w_c2 = feat_c2.size(3);
    const int h_c3 = feat_c3.size(2);
    const int w_c3 = feat_c3.size(3);
    const int h_c4 = feat_c4.size(2);
    const int w_c4 = feat_c4.size(3);
    const int h_c5 = feat_c5.size(2);
    const int w_c5 = feat_c5.size(3);
    const int h_c6 = feat_c6.size(2);
    const int w_c6 = feat_c6.size(3);

    // Output layout: [B, Q, C, P].
    auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());
    ms_deformable_im2col_cuda_c23456(
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        feat_c6.data_ptr<float>(),
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        batch_size, channels, num_views, num_query, num_point,
        output.data_ptr<float>()
    );

    return output;
}
|
| 211 |
+
|
| 212 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(
|
| 213 |
+
const at::Tensor& grad_output,
|
| 214 |
+
const at::Tensor& feat_c2, // [B, N, H, W, C]
|
| 215 |
+
const at::Tensor& feat_c3, // [B, N, H, W, C]
|
| 216 |
+
const at::Tensor& feat_c4, // [B, N, H, W, C]
|
| 217 |
+
const at::Tensor& feat_c5, // [B, N, H, W, C]
|
| 218 |
+
const at::Tensor& sampling_loc, // [B, Q, P, 3]
|
| 219 |
+
const at::Tensor& attn_weight // [B, Q, P, 4]
|
| 220 |
+
) {
|
| 221 |
+
AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
|
| 222 |
+
AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
|
| 223 |
+
AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
|
| 224 |
+
AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
|
| 225 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
| 226 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
| 227 |
+
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
| 228 |
+
|
| 229 |
+
AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
|
| 230 |
+
AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
|
| 231 |
+
AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
|
| 232 |
+
AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
|
| 233 |
+
AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
|
| 234 |
+
AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
|
| 235 |
+
AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
|
| 236 |
+
|
| 237 |
+
const int batch_size = feat_c2.size(0);
|
| 238 |
+
const int num_views = feat_c2.size(1);
|
| 239 |
+
const int channels = feat_c2.size(4);
|
| 240 |
+
const int num_query = sampling_loc.size(1);
|
| 241 |
+
const int num_point = sampling_loc.size(2);
|
| 242 |
+
AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");
|
| 243 |
+
|
| 244 |
+
auto grad_value_c2 = at::zeros_like(feat_c2);
|
| 245 |
+
auto grad_value_c3 = at::zeros_like(feat_c3);
|
| 246 |
+
auto grad_value_c4 = at::zeros_like(feat_c4);
|
| 247 |
+
auto grad_value_c5 = at::zeros_like(feat_c5);
|
| 248 |
+
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
| 249 |
+
auto grad_attn_weight = at::zeros_like(attn_weight);
|
| 250 |
+
|
| 251 |
+
const int h_c2 = feat_c2.size(2);
|
| 252 |
+
const int w_c2 = feat_c2.size(3);
|
| 253 |
+
const int h_c3 = feat_c3.size(2);
|
| 254 |
+
const int w_c3 = feat_c3.size(3);
|
| 255 |
+
const int h_c4 = feat_c4.size(2);
|
| 256 |
+
const int w_c4 = feat_c4.size(3);
|
| 257 |
+
const int h_c5 = feat_c5.size(2);
|
| 258 |
+
const int w_c5 = feat_c5.size(3);
|
| 259 |
+
|
| 260 |
+
ms_deformable_col2im_cuda_c2345(
|
| 261 |
+
grad_output.data_ptr<float>(),
|
| 262 |
+
feat_c2.data_ptr<float>(),
|
| 263 |
+
feat_c3.data_ptr<float>(),
|
| 264 |
+
feat_c4.data_ptr<float>(),
|
| 265 |
+
feat_c5.data_ptr<float>(),
|
| 266 |
+
h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
|
| 267 |
+
sampling_loc.data_ptr<float>(),
|
| 268 |
+
attn_weight.data_ptr<float>(),
|
| 269 |
+
batch_size, channels, num_views, num_query, num_point,
|
| 270 |
+
grad_value_c2.data_ptr<float>(),
|
| 271 |
+
grad_value_c3.data_ptr<float>(),
|
| 272 |
+
grad_value_c4.data_ptr<float>(),
|
| 273 |
+
grad_value_c5.data_ptr<float>(),
|
| 274 |
+
grad_sampling_loc.data_ptr<float>(),
|
| 275 |
+
grad_attn_weight.data_ptr<float>()
|
| 276 |
+
);
|
| 277 |
+
|
| 278 |
+
return {
|
| 279 |
+
grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
|
| 280 |
+
};
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// Backward entry point for the 5-level (C2..C6) deformable sampling op.
// Mirrors the C2-C5 backward with one extra level; attn_weight (and its
// gradient) holds FIVE weights per sampling point (the kernel reads
// stride-5 weight groups).
std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(
    const at::Tensor& grad_output,
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& feat_c6, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 5] -- one weight per level C2..C6
) {
    // Raw-pointer kernel: inputs must be contiguous CUDA tensors.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch_size = feat_c2.size(0);
    const int num_views = feat_c2.size(1);
    const int channels = feat_c2.size(4);
    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(2);
    // MAX_POINT is a compile-time kernel limit.
    AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");

    // Zero-initialized gradient buffers; the kernel accumulates into them
    // with atomicAdd.
    auto grad_value_c2 = at::zeros_like(feat_c2);
    auto grad_value_c3 = at::zeros_like(feat_c3);
    auto grad_value_c4 = at::zeros_like(feat_c4);
    auto grad_value_c5 = at::zeros_like(feat_c5);
    auto grad_value_c6 = at::zeros_like(feat_c6);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    // Per-level spatial sizes.
    const int h_c2 = feat_c2.size(2);
    const int w_c2 = feat_c2.size(3);
    const int h_c3 = feat_c3.size(2);
    const int w_c3 = feat_c3.size(3);
    const int h_c4 = feat_c4.size(2);
    const int w_c4 = feat_c4.size(3);
    const int h_c5 = feat_c5.size(2);
    const int w_c5 = feat_c5.size(3);
    const int h_c6 = feat_c6.size(2);
    const int w_c6 = feat_c6.size(3);

    ms_deformable_col2im_cuda_c23456(
        grad_output.data_ptr<float>(),
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        feat_c6.data_ptr<float>(),
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        batch_size, channels, num_views, num_query, num_point,
        grad_value_c2.data_ptr<float>(),
        grad_value_c3.data_ptr<float>(),
        grad_value_c4.data_ptr<float>(),
        grad_value_c5.data_ptr<float>(),
        grad_value_c6.data_ptr<float>(),
        grad_sampling_loc.data_ptr<float>(),
        grad_attn_weight.data_ptr<float>()
    );

    // Gradients returned in forward-input order.
    return {
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
    };
}
|
| 361 |
+
|
| 362 |
+
#ifdef TORCH_EXTENSION_NAME
|
| 363 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 364 |
+
m.def("_ms_deform_attn_cuda_c2345_forward", &ms_deform_attn_cuda_c2345_forward, "pass");
|
| 365 |
+
m.def("_ms_deform_attn_cuda_c2345_backward", &ms_deform_attn_cuda_c2345_backward, "pass");
|
| 366 |
+
m.def("_ms_deform_attn_cuda_c23456_forward", &ms_deform_attn_cuda_c23456_forward, "pass");
|
| 367 |
+
m.def("_ms_deform_attn_cuda_c23456_backward", &ms_deform_attn_cuda_c23456_backward, "pass");
|
| 368 |
+
}
|
| 369 |
+
#endif
|
models/csrc/msmv_sampling/msmv_sampling.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <torch/extension.h>

// Public declarations for the multi-scale multi-view deformable sampling
// CUDA ops. All feature maps are [B, N, H, W, C] float32 CUDA tensors;
// sampling_loc is [B, Q, P, 3] with normalized (w, h, view) in [0, 1].

at::Tensor ms_deform_attn_cuda_c2345_forward(
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 4]
);

// NOTE(fix): grad_output is the FIRST parameter, matching the definition in
// msmv_sampling.cpp. The previous declaration listed it last, so the header
// signature disagreed with the implementation and any caller relying on this
// header would fail to link (or call with misordered arguments).
std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(
    const at::Tensor& grad_output,
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 4]
);

at::Tensor ms_deform_attn_cuda_c23456_forward(
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& feat_c6, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 5] -- one weight per level C2..C6
);

std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(
    const at::Tensor& grad_output,
    const at::Tensor& feat_c2, // [B, N, H, W, C]
    const at::Tensor& feat_c3, // [B, N, H, W, C]
    const at::Tensor& feat_c4, // [B, N, H, W, C]
    const at::Tensor& feat_c5, // [B, N, H, W, C]
    const at::Tensor& feat_c6, // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight // [B, Q, P, 5] -- one weight per level C2..C6
);
|
models/csrc/msmv_sampling/msmv_sampling_backward.cu
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
 * Modified from Deformable DETR
 */

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>

// Grid-stride loop: each thread starts at its global index and strides by
// the total number of launched threads until n elements are covered.
#define CUDA_KERNEL_LOOP(i, n) \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
        i < (n); \
        i += blockDim.x * gridDim.x)

// Threads per block used by the kernel launches in this file.
#define CUDA_NUM_THREADS 512
// Hard upper bound on sampling points per query; host wrappers assert
// num_point <= MAX_POINT before launching.
#define MAX_POINT 32

// Number of blocks needed to cover N work items with num_threads per block
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
    return (N + num_threads - 1) / num_threads;
}
|
| 28 |
+
|
| 29 |
+
// Backward of one bilinear sample on a single [H, W, C] view.
//
// bottom_data points at the start of the view's feature slab; (h, w) is the
// unnormalized sampling position and c the channel. Accumulates, via
// atomicAdd:
//   * w_i * top_grad * attn_weight into the four neighboring cells of
//     grad_value (gradient w.r.t. the feature values),
//   * top_grad * sampled_value into grad_attn_weight,
//   * the chain-ruled position gradients into grad_sampling_loc[0] (w) and
//     grad_sampling_loc[1] (h); the (width - 1) / (height - 1) factors undo
//     the align_corners=True normalization loc * (size - 1) done by the
//     caller.
// Out-of-range corners contribute zero (same zero-padding rule as the
// forward sampler).
__device__ void ms_deform_attn_col2im_bilinear(const float *&bottom_data,
                                               const int &height, const int &width, const int &channels,
                                               const float &h, const float &w, const int &c,
                                               const float &top_grad,
                                               const float &attn_weight,
                                               const float *&grad_value,
                                               float *&grad_sampling_loc,
                                               float *&grad_attn_weight)
{
    // Integer corners surrounding (h, w).
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    // Fractional offsets (l*) and their complements (h*).
    const float lh = h - h_low;
    const float lw = w - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    // Element strides for the [H, W, C] layout.
    const int w_stride = channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    // Bilinear corner weights and the gradient flowing into the values.
    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const float top_grad_value = top_grad * attn_weight;
    // Partial derivatives of the interpolated value w.r.t. h and w,
    // built up corner by corner below.
    float grad_h_weight = 0, grad_w_weight = 0;

    float *grad_ptr;

    // Top-left corner.
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0)
    {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr1);
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_ptr, w1 * top_grad_value);
    }
    // Top-right corner.
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
    {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr2);
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_ptr, w2 * top_grad_value);
    }
    // Bottom-left corner.
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
    {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr3);
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_ptr, w3 * top_grad_value);
    }
    // Bottom-right corner.
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
    {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr4);
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_ptr, w4 * top_grad_value);
    }

    // val is the forward's interpolated sample: d(output)/d(attn_weight).
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    atomicAdd(grad_attn_weight, top_grad * val);
    atomicAdd(grad_sampling_loc, (width - 1) * grad_w_weight * top_grad_value);
    atomicAdd(grad_sampling_loc + 1, (height - 1) * grad_h_weight * top_grad_value);
}
|
| 106 |
+
|
| 107 |
+
// global_memory_way: gradients are scattered straight into global memory
// with atomicAdd (no shared-memory reduction). One thread per
// (batch, query, channel, point) output element of the forward pass.
__global__ void ms_deformable_col2im_gpu_kernel_gm_c2345(
    const float *grad_col,            // incoming gradient, one value per output element
    const float *feat_c2,             // per-level features, [B, N, H, W, C] each
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const int h_c2, const int w_c2,   // per-level spatial sizes
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float *data_sampling_loc,   // [B, Q, P, 3]: (w, h, view) in [0, 1]
    const float *data_attn_weight,    // [B, Q, P, 4]: one weight per level
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,             // outputs, accumulated with atomics
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
    { // n: bs x query x channels
        // Decompose the flat index; fastest-varying dimension is point,
        // then channel, then query, then batch. sampling_index is the
        // combined (batch * num_query + query) index used to address the
        // per-point location/weight arrays.
        int _temp = index;
        const int p_col = _temp % num_point;
        _temp /= num_point;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;
        _temp /= num_query;
        const int b_col = _temp;

        const float top_grad = grad_col[index];

        // Sampling location in range [0, 1]
        int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
        const float loc_w = data_sampling_loc[data_loc_ptr];
        const float loc_h = data_sampling_loc[data_loc_ptr + 1];
        // Third coordinate selects the camera view (nearest view, no
        // interpolation across views).
        const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

        // Attn weights
        int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;

        const float weight_c2 = data_attn_weight[data_weight_ptr];
        const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
        const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
        const float weight_c5 = data_attn_weight[data_weight_ptr + 3];

        // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
        // const float w_im = loc_w * spatial_w - 0.5;

        // C2 Feature
        float h_im = loc_h * (h_c2 - 1); // align_corners = True
        float w_im = loc_w * (w_c2 - 1);

        // Location gradient is shared across levels (same (w, h) drives
        // every level); each level has its own weight-gradient slot.
        float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
        float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;

        if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
        {
            // Offset to the (batch, view) slab of this level.
            const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
                                           top_grad, weight_c2,
                                           grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
        }

        // Advance to the next level's attention-weight gradient slot.
        grad_weights_ptr += 1;

        // C3 Feature
        h_im = loc_h * (h_c3 - 1); // align_corners = True
        w_im = loc_w * (w_c3 - 1);

        if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
        {
            const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
                                           top_grad, weight_c3,
                                           grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
        }

        grad_weights_ptr += 1;

        // C4 Feature
        h_im = loc_h * (h_c4 - 1); // align_corners = True
        w_im = loc_w * (w_c4 - 1);

        if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
        {
            const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
                                           top_grad, weight_c4,
                                           grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
        }

        grad_weights_ptr += 1;

        // C5 Feature
        h_im = loc_h * (h_c5 - 1); // align_corners = True
        w_im = loc_w * (w_c5 - 1);

        if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
        {
            const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
                                           top_grad, weight_c5,
                                           grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
        }
    }
}
|
| 225 |
+
|
| 226 |
+
__global__ void ms_deformable_col2im_gpu_kernel_gm_c23456(
|
| 227 |
+
const float *grad_col,
|
| 228 |
+
const float *feat_c2,
|
| 229 |
+
const float *feat_c3,
|
| 230 |
+
const float *feat_c4,
|
| 231 |
+
const float *feat_c5,
|
| 232 |
+
const float *feat_c6,
|
| 233 |
+
const int h_c2, const int w_c2,
|
| 234 |
+
const int h_c3, const int w_c3,
|
| 235 |
+
const int h_c4, const int w_c4,
|
| 236 |
+
const int h_c5, const int w_c5,
|
| 237 |
+
const int h_c6, const int w_c6,
|
| 238 |
+
const float *data_sampling_loc,
|
| 239 |
+
const float *data_attn_weight,
|
| 240 |
+
const int batch_size,
|
| 241 |
+
const int channels,
|
| 242 |
+
const int num_views,
|
| 243 |
+
const int num_query,
|
| 244 |
+
const int num_point,
|
| 245 |
+
float *grad_value_c2,
|
| 246 |
+
float *grad_value_c3,
|
| 247 |
+
float *grad_value_c4,
|
| 248 |
+
float *grad_value_c5,
|
| 249 |
+
float *grad_value_c6,
|
| 250 |
+
float *grad_sampling_loc,
|
| 251 |
+
float *grad_attn_weight)
|
| 252 |
+
{
|
| 253 |
+
CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
|
| 254 |
+
{ // n: bs x query x channels
|
| 255 |
+
|
| 256 |
+
int _temp = index;
|
| 257 |
+
const int p_col = _temp % num_point;
|
| 258 |
+
_temp /= num_point;
|
| 259 |
+
const int c_col = _temp % channels;
|
| 260 |
+
_temp /= channels;
|
| 261 |
+
const int sampling_index = _temp;
|
| 262 |
+
_temp /= num_query;
|
| 263 |
+
const int b_col = _temp;
|
| 264 |
+
|
| 265 |
+
const float top_grad = grad_col[index];
|
| 266 |
+
|
| 267 |
+
// Sampling location in range [0, 1]
|
| 268 |
+
int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
|
| 269 |
+
const float loc_w = data_sampling_loc[data_loc_ptr];
|
| 270 |
+
const float loc_h = data_sampling_loc[data_loc_ptr + 1];
|
| 271 |
+
const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
|
| 272 |
+
|
| 273 |
+
// Attn weights
|
| 274 |
+
int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
|
| 275 |
+
|
| 276 |
+
const float weight_c2 = data_attn_weight[data_weight_ptr];
|
| 277 |
+
const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
|
| 278 |
+
const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
|
| 279 |
+
const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
|
| 280 |
+
const float weight_c6 = data_attn_weight[data_weight_ptr + 4];
|
| 281 |
+
|
| 282 |
+
// const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
|
| 283 |
+
// const float w_im = loc_w * spatial_w - 0.5;
|
| 284 |
+
|
| 285 |
+
// C2 Feature
|
| 286 |
+
float h_im = loc_h * (h_c2 - 1); // align_corners = True
|
| 287 |
+
float w_im = loc_w * (w_c2 - 1);
|
| 288 |
+
|
| 289 |
+
float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
|
| 290 |
+
float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
|
| 291 |
+
|
| 292 |
+
if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
|
| 293 |
+
{
|
| 294 |
+
const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
|
| 295 |
+
const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
|
| 296 |
+
ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
|
| 297 |
+
top_grad, weight_c2,
|
| 298 |
+
grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
grad_weights_ptr += 1;
|
| 302 |
+
|
| 303 |
+
// C3 Feature
|
| 304 |
+
h_im = loc_h * (h_c3 - 1); // align_corners = True
|
| 305 |
+
w_im = loc_w * (w_c3 - 1);
|
| 306 |
+
|
| 307 |
+
if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
|
| 308 |
+
{
|
| 309 |
+
const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
|
| 310 |
+
const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
|
| 311 |
+
ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
|
| 312 |
+
top_grad, weight_c3,
|
| 313 |
+
grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
grad_weights_ptr += 1;
|
| 317 |
+
|
| 318 |
+
// C4 Feature
|
| 319 |
+
h_im = loc_h * (h_c4 - 1); // align_corners = True
|
| 320 |
+
w_im = loc_w * (w_c4 - 1);
|
| 321 |
+
|
| 322 |
+
if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
|
| 323 |
+
{
|
| 324 |
+
const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
|
| 325 |
+
const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
|
| 326 |
+
ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
|
| 327 |
+
top_grad, weight_c4,
|
| 328 |
+
grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
grad_weights_ptr += 1;
|
| 332 |
+
|
| 333 |
+
// C5 Feature
|
| 334 |
+
h_im = loc_h * (h_c5 - 1); // align_corners = True
|
| 335 |
+
w_im = loc_w * (w_c5 - 1);
|
| 336 |
+
|
| 337 |
+
if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
|
| 338 |
+
{
|
| 339 |
+
const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
|
| 340 |
+
const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
|
| 341 |
+
ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
|
| 342 |
+
top_grad, weight_c5,
|
| 343 |
+
grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
grad_weights_ptr += 1;
|
| 347 |
+
|
| 348 |
+
// C6 Feature
|
| 349 |
+
h_im = loc_h * (h_c6 - 1); // align_corners = True
|
| 350 |
+
w_im = loc_w * (w_c6 - 1);
|
| 351 |
+
|
| 352 |
+
if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)
|
| 353 |
+
{
|
| 354 |
+
const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
|
| 355 |
+
const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
|
| 356 |
+
ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col,
|
| 357 |
+
top_grad, weight_c6,
|
| 358 |
+
grad_c6_ptr, grad_location_ptr, grad_weights_ptr);
|
| 359 |
+
}
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
// Host launcher for the 4-level (C2..C5) backward pass.
// Scatters the incoming gradient `grad_col` into gradients w.r.t. the four
// feature maps, the sampling locations and the per-level attention weights.
void ms_deformable_col2im_cuda_c2345(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    // One thread per (batch, query, channel, point) gradient element.
    const int num_kernels = batch_size * num_query * channels * num_point;
    // Cap the block size at CUDA_NUM_THREADS; smaller problems get
    // exactly-sized blocks.
    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;

    ms_deformable_col2im_gpu_kernel_gm_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
        grad_col, feat_c2, feat_c3, feat_c4, feat_c5,
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
        data_sampling_loc, data_attn_weight,
        batch_size, channels, num_views, num_query, num_point,
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5,
        grad_sampling_loc, grad_attn_weight);

    // Check for launch errors right away (the kernel itself runs
    // asynchronously; only launch-time failures are caught here).
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_col2im_cuda_c2345: %s\n", cudaGetErrorString(err));
    }
}
|
| 404 |
+
|
| 405 |
+
// Host launcher for the 5-level (C2..C6) backward pass.
// Same contract as ms_deformable_col2im_cuda_c2345, extended with a fifth
// feature level (C6) and its gradient buffer.
void ms_deformable_col2im_cuda_c23456(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const float *feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_value_c6,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    // One thread per (batch, query, channel, point) gradient element.
    const int num_kernels = batch_size * num_query * channels * num_point;
    // Cap the block size at CUDA_NUM_THREADS; smaller problems get
    // exactly-sized blocks.
    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;

    ms_deformable_col2im_gpu_kernel_gm_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
        grad_col, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        data_sampling_loc, data_attn_weight,
        batch_size, channels, num_views, num_query, num_point,
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6,
        grad_sampling_loc, grad_attn_weight);

    // Check for launch errors right away (the kernel itself runs
    // asynchronously; only launch-time failures are caught here).
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_col2im_cuda_c23456: %s\n", cudaGetErrorString(err));
    }
}
|
models/csrc/msmv_sampling/msmv_sampling_forward.cu
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
* Modified from Deformable DETR
|
| 3 |
+
*/
|
| 4 |
+
|
| 5 |
+
#include <cstdio>
|
| 6 |
+
#include <algorithm>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
#include <cuda_runtime.h>
|
| 9 |
+
#include <device_launch_parameters.h>
|
| 10 |
+
#include <torch/extension.h>
|
| 11 |
+
#include <ATen/ATen.h>
|
| 12 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 13 |
+
#include <THC/THCAtomics.cuh>
|
| 14 |
+
|
| 15 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
| 16 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
| 17 |
+
i < (n); \
|
| 18 |
+
i += blockDim.x * gridDim.x)
|
| 19 |
+
|
| 20 |
+
#define CUDA_NUM_THREADS 512
|
| 21 |
+
#define MAX_POINT 32
|
| 22 |
+
|
| 23 |
+
// Number of thread blocks required to cover N elements with num_threads
// threads per block (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads) {
  return (N + num_threads - 1) / num_threads;
}
|
| 26 |
+
|
| 27 |
+
// Bilinear interpolation at fractional location (h, w) for channel c on a
// channel-last feature map (strides below imply a [height, width, channels]
// layout). Corner taps that fall outside the map contribute 0 (zero padding).
__device__ float ms_deform_attn_im2col_bilinear(
    const float*& bottom_data,
    const int& height, const int& width, const int& channels,
    const float& h, const float& w, const int& c) {

  // Integer corners surrounding (h, w).
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional offsets and their complements.
  const float lh = h - h_low;
  const float lw = w - w_low;
  const float hh = 1 - lh, hw = 1 - lw;

  // Strides for the channel-last layout.
  const int w_stride = channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  // Fetch the four corner values, zero where the corner is out of bounds.
  float v1 = 0;
  if (h_low >= 0 && w_low >= 0) {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
    v1 = bottom_data[ptr1];
  }
  float v2 = 0;
  if (h_low >= 0 && w_high <= width - 1) {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
    v2 = bottom_data[ptr2];
  }
  float v3 = 0;
  if (h_high <= height - 1 && w_low >= 0) {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
    v3 = bottom_data[ptr3];
  }
  float v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1) {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
    v4 = bottom_data[ptr4];
  }

  // Standard bilinear combination of the four corners.
  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  return val;
}
|
| 74 |
+
|
| 75 |
+
// Forward sampling kernel for 4 feature levels (C2..C5) over multiple camera
// views. The pointer arithmetic below implies a [batch, view, height, width,
// channel] feature layout. Each thread handles one (batch, query, channel)
// triple and accumulates the bilinearly-sampled, scale-weighted value for all
// of its sampling points.
__global__ void ms_deformable_im2col_gpu_kernel_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,   // per (batch*query, point): (w, h, view), each in [0, 1]
    const float* data_attn_weight,    // per (batch*query, point): 4 weights, one per level
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {                // output, contiguous as [batch, query, channel, point]

  float res[MAX_POINT];  // per-point accumulators; assumes num_point <= MAX_POINT — TODO confirm enforced by caller

  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
    // Unflatten index -> (b_col, sampling_index = batch*num_query + query, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    _temp /= num_query;
    const int b_col = _temp;

    for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }

    for (int p_col = 0; p_col < num_point; ++p_col) {
      // Sampling location in range [0, 1]
      int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
      const float loc_w = data_sampling_loc[data_loc_ptr];
      const float loc_h = data_sampling_loc[data_loc_ptr + 1];
      // Third coordinate selects the camera view (rounded to nearest index).
      const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

      // Attn weights (stride 4: one weight per feature level)
      int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
      const float weight_c2 = data_attn_weight[data_weight_ptr];
      const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
      const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
      const float weight_c5 = data_attn_weight[data_weight_ptr + 3];

      //const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
      //const float w_im = loc_w * spatial_w - 0.5;

      // C2 Feature
      float h_im = loc_h * (h_c2 - 1); // align_corners = True
      float w_im = loc_w * (w_c2 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
        // Offset into the selected (batch, view) slice of the C2 map.
        const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
      }

      // C3 Feature
      h_im = loc_h * (h_c3 - 1); // align_corners = True
      w_im = loc_w * (w_c3 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
        const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
      }

      // C4 Feature
      h_im = loc_h * (h_c4 - 1); // align_corners = True
      w_im = loc_w * (w_c4 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
        const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
      }

      // C5 Feature
      h_im = loc_h * (h_c5 - 1); // align_corners = True
      w_im = loc_w * (w_c5 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
        const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
      }
    }

    // Write the per-point results contiguously after the (b, q, c) offset.
    for (int p_col = 0; p_col < num_point; ++p_col) {
      float* data_col_ptr = data_col + index * num_point + p_col;
      *data_col_ptr = res[p_col];
    }
  }
}
|
| 165 |
+
|
| 166 |
+
// Forward sampling kernel for 5 feature levels (C2..C6) over multiple camera
// views. Identical structure to the C2..C5 kernel, with a fifth level and an
// attention-weight stride of 5 instead of 4.
__global__ void ms_deformable_im2col_gpu_kernel_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,   // per (batch*query, point): (w, h, view), each in [0, 1]
    const float* data_attn_weight,    // per (batch*query, point): 5 weights, one per level
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {                // output, contiguous as [batch, query, channel, point]

  float res[MAX_POINT];  // per-point accumulators; assumes num_point <= MAX_POINT — TODO confirm enforced by caller

  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
    // Unflatten index -> (b_col, sampling_index = batch*num_query + query, c_col).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    _temp /= num_query;
    const int b_col = _temp;

    for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }

    for (int p_col = 0; p_col < num_point; ++p_col) {
      // Sampling location in range [0, 1]
      int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
      const float loc_w = data_sampling_loc[data_loc_ptr];
      const float loc_h = data_sampling_loc[data_loc_ptr + 1];
      // Third coordinate selects the camera view (rounded to nearest index).
      const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

      // Attn weights (stride 5: one weight per feature level)
      int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
      const float weight_c2 = data_attn_weight[data_weight_ptr];
      const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
      const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
      const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
      const float weight_c6 = data_attn_weight[data_weight_ptr + 4];

      //const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
      //const float w_im = loc_w * spatial_w - 0.5;

      // C2 Feature
      float h_im = loc_h * (h_c2 - 1); // align_corners = True
      float w_im = loc_w * (w_c2 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
        // Offset into the selected (batch, view) slice of the C2 map.
        const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
      }

      // C3 Feature
      h_im = loc_h * (h_c3 - 1); // align_corners = True
      w_im = loc_w * (w_c3 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
        const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
      }

      // C4 Feature
      h_im = loc_h * (h_c4 - 1); // align_corners = True
      w_im = loc_w * (w_c4 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
        const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
      }

      // C5 Feature
      h_im = loc_h * (h_c5 - 1); // align_corners = True
      w_im = loc_w * (w_c5 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
        const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
      }

      // C6 Feature
      h_im = loc_h * (h_c6 - 1); // align_corners = True
      w_im = loc_w * (w_c6 - 1);

      if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6) {
        const float* feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
        res[p_col] += ms_deform_attn_im2col_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col) * weight_c6;
      }
    }

    // Write the per-point results contiguously after the (b, q, c) offset.
    for (int p_col = 0; p_col < num_point; ++p_col) {
      float* data_col_ptr = data_col + index * num_point + p_col;
      *data_col_ptr = res[p_col];
    }
  }
}
|
| 268 |
+
|
| 269 |
+
// Host launcher for the 4-level (C2..C5) forward pass.
// Fills `data_col` with the sampled, scale-weighted features.
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {

  // One thread per (batch, query, channel) triple; each thread loops over
  // its sampling points inside the kernel.
  const int num_kernels = batch_size * num_query * channels;
  const int num_threads = CUDA_NUM_THREADS;

  ms_deformable_im2col_gpu_kernel_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
    feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
    data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
  );

  // Check for launch errors right away (the kernel itself runs
  // asynchronously; only launch-time failures are caught here).
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_im2col_cuda_c2345: %s\n", cudaGetErrorString(err));
  }
}
|
| 300 |
+
|
| 301 |
+
// Host launcher for the 5-level (C2..C6) forward pass.
// Same contract as ms_deformable_im2col_cuda_c2345 with an extra level.
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {

  // One thread per (batch, query, channel) triple; each thread loops over
  // its sampling points inside the kernel.
  const int num_kernels = batch_size * num_query * channels;
  const int num_threads = CUDA_NUM_THREADS;

  ms_deformable_im2col_gpu_kernel_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
    feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
    data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
  );

  // Check for launch errors right away (the kernel itself runs
  // asynchronously; only launch-time failures are caught here).
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_im2col_cuda_c23456: %s\n", cudaGetErrorString(err));
  }
}
|
models/csrc/setup.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def get_ext_modules():
    """Return the CUDA extension module(s) to compile for this package."""
    source_files = [
        'msmv_sampling/msmv_sampling.cpp',
        'msmv_sampling/msmv_sampling_forward.cu',
        'msmv_sampling/msmv_sampling_backward.cu',
    ]
    extension = CUDAExtension(
        name='_msmv_sampling_cuda',
        sources=source_files,
        include_dirs=['msmv_sampling'],
    )
    return [extension]


setup(
    name='csrc',
    ext_modules=get_ext_modules(),
    cmdclass={'build_ext': BuildExtension},
)
|
| 24 |
+
|
models/csrc/wrapper.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
|
| 4 |
+
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
|
| 8 |
+
"""
|
| 9 |
+
value: [B, N, H1W1 + H2W2..., C]
|
| 10 |
+
sampling_locations: [B, Q, P, 3]
|
| 11 |
+
scale_weights: [B, Q, P, 4]
|
| 12 |
+
"""
|
| 13 |
+
assert scale_weights.shape[-1] == len(mlvl_feats)
|
| 14 |
+
|
| 15 |
+
B, _, _, _, C = mlvl_feats[0].shape
|
| 16 |
+
_, Q, P, _ = sampling_locations.shape
|
| 17 |
+
|
| 18 |
+
sampling_locations = sampling_locations * 2 - 1
|
| 19 |
+
sampling_locations = sampling_locations[:, :, :, None, :] # [B, Q, P, 1, 3]
|
| 20 |
+
|
| 21 |
+
final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
|
| 22 |
+
|
| 23 |
+
for lvl, feat in enumerate(mlvl_feats):
|
| 24 |
+
feat = feat.permute(0, 4, 1, 2, 3)
|
| 25 |
+
out = F.grid_sample(
|
| 26 |
+
feat, sampling_locations, mode='bilinear',
|
| 27 |
+
padding_mode='zeros', align_corners=True,
|
| 28 |
+
)[..., 0] # [B, C, Q, P]
|
| 29 |
+
out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
|
| 30 |
+
final += out
|
| 31 |
+
|
| 32 |
+
return final.permute(0, 2, 1, 3)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MSMVSamplingC2345(torch.autograd.Function):
    """Autograd bridge to the compiled CUDA kernels for 4 feature levels
    (C2..C5). All shape handling happens inside the extension.
    """

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
        # Save every input: the backward kernel re-reads features, locations
        # and weights to produce all six gradients.
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)

        # Guard against a missing / partially-built extension module.
        assert callable(_ms_deform_attn_cuda_c2345_forward)
        return _ms_deform_attn_cuda_c2345_forward(
            feat_c2, feat_c3, feat_c4, feat_c5,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors

        assert callable(_ms_deform_attn_cuda_c2345_backward)
        # The CUDA op needs a contiguous incoming gradient.
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(),
            feat_c2, feat_c3, feat_c4, feat_c5,
            sampling_locations, scale_weights
        )

        # One gradient per forward input, in the same order.
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MSMVSamplingC23456(torch.autograd.Function):
    """Autograd bridge to the compiled CUDA kernels for 5 feature levels
    (C2..C6). Mirrors MSMVSamplingC2345 with one extra feature level.
    """

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
        # Save every input: the backward kernel re-reads features, locations
        # and weights to produce all seven gradients.
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)

        # Guard against a missing / partially-built extension module.
        assert callable(_ms_deform_attn_cuda_c23456_forward)
        return _ms_deform_attn_cuda_c23456_forward(
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors

        assert callable(_ms_deform_attn_cuda_c23456_backward)
        # The CUDA op needs a contiguous incoming gradient.
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(),
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            sampling_locations, scale_weights
        )

        # One gradient per forward input, in the same order.
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
    """Dispatch multi-scale multi-view sampling to the matching CUDA op.

    4 or 5 feature levels go through the dedicated autograd Functions; any
    other level count falls back to the pure-PyTorch implementation.
    """
    num_levels = len(mlvl_feats)
    if num_levels == 4:
        sampler = MSMVSamplingC2345.apply
    elif num_levels == 5:
        sampler = MSMVSamplingC23456.apply
    else:
        # No specialized kernel for this level count.
        return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
    return sampler(*mlvl_feats, sampling_locations, scale_weights)
|
models/sparsebev.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import queue
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from mmcv.runner import force_fp32, auto_fp16
|
| 5 |
+
from mmcv.runner import get_dist_info
|
| 6 |
+
from mmcv.runner.fp16_utils import cast_tensor_type
|
| 7 |
+
from mmdet.models import DETECTORS
|
| 8 |
+
from mmdet3d.core import bbox3d2result
|
| 9 |
+
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
|
| 10 |
+
from .utils import GridMask, pad_multiple, GpuPhotoMetricDistortion
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@DETECTORS.register_module()
|
| 14 |
+
class SparseBEV(MVXTwoStageDetector):
|
| 15 |
+
def __init__(self,
             data_aug=None,
             stop_prev_grad=False,
             pts_voxel_layer=None,
             pts_voxel_encoder=None,
             pts_middle_encoder=None,
             pts_fusion_layer=None,
             img_backbone=None,
             pts_backbone=None,
             img_neck=None,
             pts_neck=None,
             pts_bbox_head=None,
             img_roi_head=None,
             img_rpn_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None):
    """Camera-only detector built on top of MVXTwoStageDetector.

    Args:
        data_aug (dict, optional): GPU-side augmentation config consumed by
            ``extract_feat`` (keys used there: 'img_color_aug',
            'img_norm_cfg', 'img_pad_cfg').
        stop_prev_grad (bool): if True, during training only the first
            frame group keeps gradients in ``extract_feat``; the remaining
            (previous) frames are run under ``torch.no_grad()``.
        pts_* / img_* / train_cfg / test_cfg / pretrained: passed straight
            through to MVXTwoStageDetector.
    """
    super(SparseBEV, self).__init__(pts_voxel_layer, pts_voxel_encoder,
                                    pts_middle_encoder, pts_fusion_layer,
                                    img_backbone, pts_backbone, img_neck, pts_neck,
                                    pts_bbox_head, img_roi_head, img_rpn_head,
                                    train_cfg, test_cfg, pretrained)
    self.data_aug = data_aug
    self.stop_prev_grad = stop_prev_grad
    # GPU photometric distortion; applied in extract_feat only when
    # data_aug['img_color_aug'] is set and the model is training.
    self.color_aug = GpuPhotoMetricDistortion()
    self.grid_mask = GridMask(ratio=0.5, prob=0.7)
    self.use_grid_mask = True

    # Per-frame feature cache used by simple_test_online:
    # self.memory maps a frame's first camera filename -> extracted features;
    # self.queue records insertion order so the cache can be bounded (FIFO
    # eviction happens at the end of simple_test_online).
    self.memory = {}
    self.queue = queue.Queue()
|
| 46 |
+
@auto_fp16(apply_to=('img',), out_fp32=True)
def extract_img_feat(self, img):
    """Extract multi-scale features from a batch of images.

    Fix: ``apply_to`` previously received the plain string ``('img')``
    instead of the 1-tuple ``('img',)``. It only worked by accident,
    because mmcv checks argument names with ``in`` and ``'img' in 'img'``
    is a substring match on a string. Passing a real tuple matches the
    documented API.

    Args:
        img (torch.Tensor): flattened images, [B*N, C, H, W].

    Returns:
        list[torch.Tensor] | tuple: per-level feature maps from the
        backbone (and neck, when one is configured), cast back to fp32.
    """
    if self.use_grid_mask:
        img = self.grid_mask(img)

    img_feats = self.img_backbone(img)

    # Some backbones (e.g. dict-returning ones) yield a name->tensor map;
    # normalize to a list for the neck.
    if isinstance(img_feats, dict):
        img_feats = list(img_feats.values())

    if self.with_img_neck:
        img_feats = self.img_neck(img_feats)

    return img_feats
| 61 |
+
def extract_feat(self, img, img_metas):
    """Extract per-view image features, applying GPU-side augmentations.

    Args:
        img (torch.Tensor | list[torch.Tensor]): multi-view images,
            [B, N, C, H, W] after stacking.
        img_metas (list[dict]): mutated in place — 'img_shape', 'ori_shape'
            and 'input_shape' are (re)written here.

    Returns:
        list[torch.Tensor]: per-level features, each [B, N, C', H', W'].
    """
    if isinstance(img, list):
        img = torch.stack(img, dim=0)

    assert img.dim() == 5

    B, N, C, H, W = img.size()
    img = img.view(B * N, C, H, W)
    img = img.float()

    # move some augmentations to GPU
    if self.data_aug is not None:
        if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:
            img = self.color_aug(img)

        if 'img_norm_cfg' in self.data_aug:
            img_norm_cfg = self.data_aug['img_norm_cfg']

            norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)
            norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)

            if img_norm_cfg['to_rgb']:
                img = img[:, [2, 1, 0], :, :]  # BGR to RGB

            img = img - norm_mean.reshape(1, 3, 1, 1)
            img = img / norm_std.reshape(1, 3, 1, 1)

        # Record the post-augmentation shape as (H, W, C) for every view.
        for b in range(B):
            img_shape = (img.shape[2], img.shape[3], img.shape[1])
            img_metas[b]['img_shape'] = [img_shape for _ in range(N)]
            img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]

        if 'img_pad_cfg' in self.data_aug:
            img_pad_cfg = self.data_aug['img_pad_cfg']
            img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])

    input_shape = img.shape[-2:]
    # update real input shape of each single img
    for img_meta in img_metas:
        img_meta.update(input_shape=input_shape)

    if self.training and self.stop_prev_grad:
        # Only the most recent frame group keeps gradients; all previous
        # frames are run in eval mode under no_grad to save memory.
        # NOTE(review): the hard-coded 6 assumes 6 cameras per frame
        # (nuScenes rig) — confirm if reused on other datasets.
        H, W = input_shape
        img = img.reshape(B, -1, 6, C, H, W)

        img_grad = img[:, :1]
        img_nograd = img[:, 1:]

        all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]

        with torch.no_grad():
            self.eval()
            for k in range(img_nograd.shape[1]):
                all_img_feats.append(self.extract_img_feat(img_nograd[:, k].reshape(-1, C, H, W)))
            self.train()

        # Re-concatenate per level along the frame dimension.
        img_feats = []
        for lvl in range(len(all_img_feats[0])):
            # NOTE: C, H, W are deliberately rebound here to the feature
            # map's dimensions (not the input image's).
            C, H, W = all_img_feats[0][lvl].shape[1:]
            img_feat = torch.cat([feat[lvl].reshape(B, -1, 6, C, H, W) for feat in all_img_feats], dim=1)
            img_feat = img_feat.reshape(-1, C, H, W)
            img_feats.append(img_feat)
    else:
        img_feats = self.extract_img_feat(img)

    # Restore the batch/view split: [B*N, C, H, W] -> [B, N, C, H, W].
    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))

    return img_feats_reshaped
| 133 |
+
def forward_pts_train(self,
                      pts_feats,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      img_metas,
                      gt_bboxes_ignore=None):
    """Forward function for point cloud branch.
    Args:
        pts_feats (list[torch.Tensor]): Features of point cloud branch
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
            boxes for each sample.
        gt_labels_3d (list[torch.Tensor]): Ground truth labels for
            boxes of each sample
        img_metas (list[dict]): Meta information of samples.
        gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
            boxes to be ignored. Defaults to None.
    Returns:
        dict: Losses of each branch.
    """
    # Despite the name, this detector feeds image features here;
    # the head consumes them together with img_metas.
    outs = self.pts_bbox_head(pts_feats, img_metas)
    loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
    losses = self.pts_bbox_head.loss(*loss_inputs)

    return losses
| 158 |
+
@force_fp32(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs):
    """Route the call to training or testing forward pass.

    With ``return_loss=True`` the inputs are single-nested (``img`` is a
    Tensor, ``img_metas`` a list[dict]); with ``return_loss=False`` they
    are double-nested (list[Tensor], list[list[dict]]), the outer list
    indexing test-time augmentations.
    """
    if return_loss:
        return self.forward_train(**kwargs)
    return self.forward_test(**kwargs)
| 174 |
+
def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  img_depth=None,
                  img_mask=None):
    """Training forward pass.

    Args:
        points (list[torch.Tensor], optional): not used by this detector;
            kept for interface compatibility.
        img_metas (list[dict], optional): per-sample meta information.
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional): ground
            truth 3D boxes.
        gt_labels_3d (list[torch.Tensor], optional): ground truth labels
            of the 3D boxes.
        gt_labels (list[torch.Tensor], optional): 2D labels; unused here.
        gt_bboxes (list[torch.Tensor], optional): 2D boxes; unused here.
        img (torch.Tensor, optional): multi-view images, [B, N, C, H, W].
        proposals, gt_bboxes_ignore, img_depth, img_mask: unused extras
            kept for interface compatibility.

    Returns:
        dict: losses from the detection head.
    """
    img_feats = self.extract_feat(img, img_metas)

    # The head also reads the GT through img_metas (e.g. for query
    # denoising), so stash it there per sample.
    for idx, meta in enumerate(img_metas):
        meta['gt_bboxes_3d'] = gt_bboxes_3d[idx]
        meta['gt_labels_3d'] = gt_labels_3d[idx]

    return self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d,
                                  img_metas, gt_bboxes_ignore)
| 219 |
+
def forward_test(self, img_metas, img=None, **kwargs):
    """Test-time entry point.

    Inputs are double-nested over test-time augmentations; only the
    first augmentation is consumed.
    """
    if not isinstance(img_metas, list):
        raise TypeError('{} must be a list, but got {}'.format(
            'img_metas', type(img_metas)))

    # Keep the indexing below uniform even when no images are supplied.
    if img is None:
        img = [img]

    return self.simple_test(img_metas[0], img[0], **kwargs)
| 227 |
+
def simple_test_pts(self, x, img_metas, rescale=False):
    """Run the detection head on features and convert to result dicts."""
    preds = self.pts_bbox_head(x, img_metas)
    raw_results = self.pts_bbox_head.get_bboxes(preds, img_metas[0], rescale=rescale)
    return [
        bbox3d2result(bboxes, scores, labels)
        for bboxes, scores, labels in raw_results
    ]
| 238 |
+
def simple_test(self, img_metas, img=None, rescale=False):
    """Dispatch to the online (single-process) or offline (distributed) path."""
    _, world_size = get_dist_info()
    if world_size == 1:  # online
        return self.simple_test_online(img_metas, img, rescale)
    if world_size > 1:  # offline
        return self.simple_test_offline(img_metas, img, rescale)
| 245 |
+
def simple_test_offline(self, img_metas, img=None, rescale=False):
    """Standard single-pass test: extract features, then run the head."""
    self.fp16_enabled = False
    feats = self.extract_feat(img=img, img_metas=img_metas)

    pts_results = self.simple_test_pts(feats, img_metas, rescale=rescale)

    results = [dict() for _ in range(len(img_metas))]
    for result_dict, pts_bbox in zip(results, pts_results):
        result_dict['pts_bbox'] = pts_bbox
    return results
| 256 |
+
def simple_test_online(self, img_metas, img=None, rescale=False):
    """Streaming single-sample test with a per-frame feature cache.

    Features of past frames are looked up in ``self.memory`` (keyed by the
    frame's first camera filename) instead of being re-extracted; only the
    current frame is run through the backbone. The cache is bounded by a
    FIFO queue (evicted below 8 entries at the end of each call).
    """
    self.fp16_enabled = False
    assert len(img_metas) == 1  # batch_size = 1

    B, N, C, H, W = img.shape
    # NOTE(review): 6 cameras per frame is assumed throughout this method.
    img = img.reshape(B, N//6, 6, C, H, W)

    img_filenames = img_metas[0]['filename']
    num_frames = len(img_filenames) // 6
    # assert num_frames == img.shape[1]

    img_shape = (H, W, C)
    img_metas[0]['img_shape'] = [img_shape for _ in range(len(img_filenames))]
    img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))]
    img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))]

    img_feats_large, img_metas_large = [], []

    for i in range(num_frames):
        img_indices = list(np.arange(i * 6, (i + 1) * 6))

        # Only the current frame's pixels are ever used (see the
        # `assert i == 0` below); all other frames must already be cached.
        # Assumes frames arrive in temporal order so every non-current
        # frame was cached by an earlier call — TODO confirm.
        img_curr_large = img[:, 0]  # [B, 6, C, H, W]
        img_metas_curr_large = [{}]

        # Slice the per-view meta lists down to this frame's 6 cameras.
        # (the comprehension's `i` shadows the loop variable only inside
        # the comprehension scope)
        for k in img_metas[0].keys():
            if isinstance(img_metas[0][k], list):
                img_metas_curr_large[0][k] = [img_metas[0][k][i] for i in img_indices]

        if img_filenames[img_indices[0]] in self.memory:
            img_feats_curr_large = self.memory[img_filenames[img_indices[0]]]
        else:
            assert i == 0
            img_feats_curr_large = self.extract_feat(img_curr_large, img_metas_curr_large)
            self.memory[img_filenames[img_indices[0]]] = img_feats_curr_large
            self.queue.put(img_filenames[img_indices[0]])

        img_feats_large.append(img_feats_curr_large)
        img_metas_large.append(img_metas_curr_large)

    # reorganize: concat frames per level, then fold (frame, view) into
    # a single view axis with batch dim 1.
    feat_levels = len(img_feats_large[0])
    img_feats_large_reorganized = []
    for j in range(feat_levels):
        feat_l = torch.cat([img_feats_large[i][j] for i in range(len(img_feats_large))], dim=0)
        feat_l = feat_l.flatten(0, 1)[None, ...]
        img_feats_large_reorganized.append(feat_l)

    # Merge per-frame meta lists back into one meta dict.
    img_metas_large_reorganized = img_metas_large[0]
    for i in range(1, len(img_metas_large)):
        for k, v in img_metas_large[i][0].items():
            if isinstance(v, list):
                img_metas_large_reorganized[0][k].extend(v)

    img_feats = img_feats_large_reorganized
    img_metas = img_metas_large_reorganized
    # Cached features may be fp16; the head expects fp32.
    img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)

    bbox_list = [dict() for _ in range(1)]
    bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
    for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
        result_dict['pts_bbox'] = pts_bbox

    # FIFO eviction keeps the feature cache bounded.
    while self.queue.qsize() >= 8:
        pop_key = self.queue.get()
        self.memory.pop(pop_key)

    return bbox_list
models/sparsebev_head.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.runner import force_fp32
|
| 5 |
+
from mmdet.core import multi_apply, reduce_mean
|
| 6 |
+
from mmdet.models import HEADS
|
| 7 |
+
from mmdet.models.dense_heads import DETRHead
|
| 8 |
+
from mmdet3d.core.bbox.coders import build_bbox_coder
|
| 9 |
+
from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
|
| 10 |
+
from .bbox.utils import normalize_bbox, encode_bbox
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@HEADS.register_module()
|
| 14 |
+
class SparseBEVHead(DETRHead):
|
| 15 |
+
def __init__(self,
             *args,
             num_classes,
             in_channels,
             query_denoising=True,
             query_denoising_groups=10,
             bbox_coder=None,
             code_size=10,
             code_weights=[1.0] * 10,
             train_cfg=dict(),
             test_cfg=dict(max_per_img=100),
             **kwargs):
    """Sparse query-based detection head with optional query denoising.

    Args:
        num_classes (int): number of object classes.
        in_channels (int): feature channels; also used as embed_dims.
        query_denoising (bool): enable DN-DETR-style query denoising.
        query_denoising_groups (int): number of denoising groups.
        bbox_coder (dict): config for the box coder (also supplies pc_range).
        code_size (int): box code length (default 10:
            cx, cy, cz, w, l, h, sin, cos, vx, vy).
        code_weights (list[float]): per-dimension regression weights.

    NOTE(review): ``code_weights``/``train_cfg``/``test_cfg`` use mutable
    defaults; safe only as long as they are never mutated in place.
    """
    # These attributes must exist before super().__init__(), which
    # (via DETRHead) invokes _init_layers() — that reads num_classes
    # and embed_dims.
    self.code_size = code_size
    self.code_weights = code_weights
    self.num_classes = num_classes
    self.in_channels = in_channels
    self.train_cfg = train_cfg
    self.test_cfg = test_cfg
    self.fp16_enabled = False
    self.embed_dims = in_channels

    super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)

    # Frozen buffer-like parameter so the weights move with the module
    # but are never trained.
    self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
    self.bbox_coder = build_bbox_coder(bbox_coder)
    self.pc_range = self.bbox_coder.pc_range

    # Query-denoising hyper-parameters (DN-DETR style).
    self.dn_enabled = query_denoising
    self.dn_group_num = query_denoising_groups
    self.dn_weight = 1.0
    self.dn_bbox_noise_scale = 0.5
    self.dn_label_noise_scale = 0.5
|
| 48 |
+
def _init_layers(self):
    """Create query-box and label embeddings and seed the query grid."""
    # Query box layout: (x, y, z, w, l, h, sin, cos, vx, vy).
    self.init_query_bbox = nn.Embedding(self.num_query, 10)
    # Label embedding with one extra slot for "no object" (DAB-DETR);
    # one feature channel is reserved for the denoising indicator bit.
    self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)

    weight = self.init_query_bbox.weight
    nn.init.zeros_(weight[:, 2:3])          # z starts at 0
    nn.init.zeros_(weight[:, 8:10])         # zero initial velocity
    nn.init.constant_(weight[:, 5:6], 1.5)  # default box height

    # Spread the (x, y) centers on a uniform grid over (0, 1).
    grid_size = int(math.sqrt(self.num_query))
    assert grid_size * grid_size == self.num_query
    ticks = torch.arange(grid_size)
    xx, yy = torch.meshgrid(ticks, ticks, indexing='ij')  # [0, grid_size - 1]
    centers = (torch.stack([xx, yy], dim=-1) + 0.5) / grid_size
    with torch.no_grad():
        weight[:, :2] = centers.reshape(-1, 2)  # [Q, 2]
|
| 65 |
+
def init_weights(self):
    """Initialize the transformer only; query embeddings are set in _init_layers."""
    self.transformer.init_weights()
|
| 68 |
+
def forward(self, mlvl_feats, img_metas):
    """Predict classes and boxes from multi-level image features.

    Args:
        mlvl_feats (list[torch.Tensor]): per-level features, batch-first.
        img_metas (list[dict]): per-sample meta info (also carries GT for
            query denoising during training).

    Returns:
        dict: 'all_cls_scores' / 'all_bbox_preds' per decoder layer, plus
        'dn_mask_dict' when query denoising produced padded DN queries.
    """
    query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]
    #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()

    B = mlvl_feats[0].shape[0]
    # Prepend noised GT ("denoising") queries and build the attention
    # mask that isolates them from the matching queries.
    query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)

    cls_scores, bbox_preds = self.transformer(
        query_bbox,
        query_feat,
        mlvl_feats,
        attn_mask=attn_mask,
        img_metas=img_metas,
    )

    # Denormalize centers from [0, 1] to metric coordinates in pc_range.
    bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
    bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
    bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]

    # Reorder channels to the convention expected downstream.
    bbox_preds = torch.cat([
        bbox_preds[..., 0:2],
        bbox_preds[..., 3:5],
        bbox_preds[..., 2:3],
        bbox_preds[..., 5:10],
    ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]

    if mask_dict is not None and mask_dict['pad_size'] > 0:
        # Split the padded DN queries (first pad_size) from the real ones.
        output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
        output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
        output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
        output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]
        mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
        outs = {
            'all_cls_scores': output_cls_scores,
            'all_bbox_preds': output_bbox_preds,
            'enc_cls_scores': None,
            'enc_bbox_preds': None,
            'dn_mask_dict': mask_dict,
        }
    else:
        outs = {
            'all_cls_scores': cls_scores,
            'all_bbox_preds': bbox_preds,
            'enc_cls_scores': None,
            'enc_bbox_preds': None,
        }

    return outs
|
| 117 |
+
def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
    """Build query tensors, optionally prepending denoising (DN) queries.

    During training with DN enabled, each GT box is repeated dn_group_num
    times, noised, and placed in a padded block in front of the matching
    queries; an attention mask keeps the groups and the matching queries
    from attending to each other's DN slots.

    Returns:
        tuple: (input_query_bbox [B, pad+Q, 10],
                input_query_feat [B, pad+Q, embed_dims],
                attn_mask or None, mask_dict or None).
    """
    device = init_query_bbox.device
    # Matching queries get the "no object" label embedding plus a 0
    # indicator bit in the last channel (DN queries get 1).
    indicator0 = torch.zeros([self.num_query, 1], device=device)
    init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
    init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)

    if self.training and self.dn_enabled:
        # GT boxes as (gravity-center xyz + remaining dims), per sample.
        targets = [{
            'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center,
                                 m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(),
            'labels': m['gt_labels_3d'].cuda().long()
        } for m in img_metas]

        known = [torch.ones_like(t['labels'], device=device) for t in targets]
        known_num = [sum(k) for k in known]

        # can be modified to selectively denoise some label or boxes; also known label prediction
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets]).clone()
        bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        # add noise: replicate every GT once per denoising group
        known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
        known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
        known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
        known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
        known_labels_expand = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        # noise on the box: jitter the center within half the box extent
        if self.dn_bbox_noise_scale > 0:
            wlh = known_bbox_expand[..., 3:6].clone()
            rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
            known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale
            # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale
            # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale

        # Normalize to the query parameterization and clamp into [0, 1].
        known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
        known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)
        # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)

        # noise on the label: flip a fraction of labels to random classes
        if self.dn_label_noise_scale > 0:
            p = torch.rand_like(known_labels_expand.float())
            chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
            new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
            known_labels_expand.scatter_(0, chosen_indice, new_label)

        known_feat_expand = label_enc(known_labels_expand)
        indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
        known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)

        # construct final query: DN pad block first, matching queries after
        dn_single_pad = int(max(known_num))
        dn_pad_size = int(dn_single_pad * self.dn_group_num)
        dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
        dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
        input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
        input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)

        if len(known_num):
            map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()

        if len(known_bid):
            # Scatter noised boxes/features into each sample's DN slots.
            input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
            input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand

        total_size = dn_pad_size + self.num_query
        attn_mask = torch.ones([total_size, total_size], device=device) < 0  # all-False mask

        # match query cannot see the reconstruct (DN) queries
        attn_mask[dn_pad_size:, :dn_pad_size] = True
        # each DN group is blinded to every other DN group
        for i in range(self.dn_group_num):
            if i == 0:
                attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
            if i == self.dn_group_num - 1:
                attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
            else:
                attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True

        mask_dict = {
            'known_indice': torch.as_tensor(known_indice).long(),
            'batch_idx': torch.as_tensor(batch_idx).long(),
            'map_known_indice': torch.as_tensor(map_known_indice).long(),
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'pad_size': dn_pad_size
        }
    else:
        input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
        input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
        attn_mask = None
        mask_dict = None

    return input_query_bbox, input_query_feat, attn_mask, mask_dict
|
| 218 |
+
def prepare_for_dn_loss(self, mask_dict):
    """Gather the DN predictions that correspond to each known GT box.

    Returns:
        tuple: (known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt),
        with predictions reindexed to [num_layers, num_known, C].
    """
    cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']
    known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
    map_known_indice = mask_dict['map_known_indice'].long()
    known_indice = mask_dict['known_indice'].long()
    bid = mask_dict['batch_idx'].long()[known_indice]
    num_tgt = known_indice.numel()

    if len(cls_scores) > 0:
        # [L, B, Q, C] -> [B, Q, L, C], pick the known slots, back to [L, N, C].
        cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
        bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)

    return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt
|
| 233 |
+
def dn_loss_single(self,
                   cls_scores,
                   bbox_preds,
                   known_bboxs,
                   known_labels,
                   num_total_pos=None):
    """Denoising loss for a single decoder layer.

    Args:
        cls_scores (Tensor): DN classification logits for this layer.
        bbox_preds (Tensor): DN box predictions for this layer.
        known_bboxs (Tensor): the (un-noised) GT boxes the DN queries
            must reconstruct.
        known_labels (Tensor): the GT labels for those boxes.
        num_total_pos (int): number of DN targets on this rank.

    Returns:
        tuple(Tensor, Tensor): weighted (loss_cls, loss_bbox).
    """
    # Compute the average number of gt boxes across all gpus
    num_total_pos = cls_scores.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()

    # cls loss
    cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
    bbox_weights = torch.ones_like(bbox_preds)
    label_weights = torch.ones_like(known_labels)
    loss_cls = self.loss_cls(
        cls_scores,
        known_labels.long(),
        label_weights,
        avg_factor=num_total_pos
    )

    # regression L1 loss
    bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
    normalized_bbox_targets = normalize_bbox(known_bboxs)
    # Drop targets with non-finite entries (e.g. NaN velocities).
    isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
    bbox_weights = bbox_weights * self.code_weights
    loss_bbox = self.loss_bbox(
        bbox_preds[isnotnan, :10],
        normalized_bbox_targets[isnotnan, :10],
        bbox_weights[isnotnan, :10],
        avg_factor=num_total_pos
    )

    # Scale by the denoising weight and guard against NaN/Inf losses.
    loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)
    loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)

    return loss_cls, loss_bbox
|
| 271 |
+
@force_fp32(apply_to=('preds_dicts',))
def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
    """Add per-decoder-layer denoising losses to ``loss_dict``.

    Fix: ``apply_to`` previously received the plain string
    ``('preds_dicts')`` instead of the 1-tuple ``('preds_dicts',)``; it
    only worked because mmcv's membership check does substring matching
    on strings. A real tuple matches the documented API.

    Args:
        loss_dict (dict): loss dict to extend (mutated and returned).
        preds_dicts (dict): head outputs containing 'dn_mask_dict'.
        num_dec_layers (int): number of decoder layers.

    Returns:
        dict: ``loss_dict`` with 'loss_cls_dn'/'loss_bbox_dn' for the last
        layer and 'd{i}.loss_cls_dn'/'d{i}.loss_bbox_dn' for earlier ones.
    """
    known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \
        self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])

    # The same targets apply to every decoder layer.
    all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
    all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
    all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]

    dn_losses_cls, dn_losses_bbox = multi_apply(
        self.dn_loss_single, cls_scores, bbox_preds,
        all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)

    # Last layer gets the unprefixed keys; earlier layers are indexed.
    loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
    loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]
    for num_dec_layer, (loss_cls_i, loss_bbox_i) in enumerate(
            zip(dn_losses_cls[:-1], dn_losses_bbox[:-1])):
        loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i
        loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i

    return loss_dict
|
| 295 |
+
def _get_target_single(self,
                       cls_score,
                       bbox_pred,
                       gt_labels,
                       gt_bboxes,
                       gt_bboxes_ignore=None):
    """Compute matching targets for one sample.

    Args:
        cls_score (Tensor): predicted class logits, [num_query, C].
        bbox_pred (Tensor): predicted boxes, [num_query, code_size].
        gt_labels (Tensor): GT labels for this sample.
        gt_bboxes (Tensor): GT boxes for this sample.
        gt_bboxes_ignore: unused, kept for interface compatibility.

    Returns:
        tuple: (labels, label_weights, bbox_targets, bbox_weights,
        pos_inds, neg_inds) for this sample.
    """
    num_bboxes = bbox_pred.size(0)

    # assigner and sampler (one-to-one Hungarian matching)
    assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)
    sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds

    # label targets: unmatched queries default to the background class
    labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_bboxes.new_ones(num_bboxes)

    # bbox targets: only matched queries contribute regression loss
    bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
    bbox_weights = torch.zeros_like(bbox_pred)
    bbox_weights[pos_inds] = 1.0

    # DETR
    bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)
|
| 323 |
+
def get_targets(self,
|
| 324 |
+
cls_scores_list,
|
| 325 |
+
bbox_preds_list,
|
| 326 |
+
gt_bboxes_list,
|
| 327 |
+
gt_labels_list,
|
| 328 |
+
gt_bboxes_ignore_list=None):
|
| 329 |
+
assert gt_bboxes_ignore_list is None, \
|
| 330 |
+
'Only supports for gt_bboxes_ignore setting to None.'
|
| 331 |
+
num_imgs = len(cls_scores_list)
|
| 332 |
+
gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
|
| 333 |
+
|
| 334 |
+
(labels_list, label_weights_list, bbox_targets_list,
|
| 335 |
+
bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
|
| 336 |
+
self._get_target_single, cls_scores_list, bbox_preds_list,
|
| 337 |
+
gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)
|
| 338 |
+
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
|
| 339 |
+
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
|
| 340 |
+
return (labels_list, label_weights_list, bbox_targets_list,
|
| 341 |
+
bbox_weights_list, num_total_pos, num_total_neg)
|
| 342 |
+
|
| 343 |
+
    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Compute classification and regression losses for one decoder layer.

        Args:
            cls_scores (Tensor): [num_imgs, num_query, cls_out_channels].
            bbox_preds (Tensor): [num_imgs, num_query, code_size].
            gt_bboxes_list (list[Tensor]): per-image ground-truth boxes.
            gt_labels_list (list[Tensor]): per-image ground-truth labels.
            gt_bboxes_ignore_list: unsupported; must be None.

        Returns:
            tuple[Tensor, Tensor]: (loss_cls, loss_bbox), NaNs zeroed out.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        # flatten per-image targets into one batch-wide tensor each
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            # average the factor across GPUs so every rank normalizes alike
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))

        cls_avg_factor = max(cls_avg_factor, 1)
        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes accross all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # regression L1 loss, skipping any target row with non-finite values
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(bbox_targets)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights

        # only the first 10 box dims contribute to the regression loss
        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        # guard against NaN losses poisoning the whole training run
        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)

        return loss_cls, loss_bbox
@force_fp32(apply_to=('preds_dicts'))
|
| 399 |
+
def loss(self,
|
| 400 |
+
gt_bboxes_list,
|
| 401 |
+
gt_labels_list,
|
| 402 |
+
preds_dicts,
|
| 403 |
+
gt_bboxes_ignore=None):
|
| 404 |
+
assert gt_bboxes_ignore is None, \
|
| 405 |
+
f'{self.__class__.__name__} only supports ' \
|
| 406 |
+
f'for gt_bboxes_ignore setting to None.'
|
| 407 |
+
|
| 408 |
+
all_cls_scores = preds_dicts['all_cls_scores']
|
| 409 |
+
all_bbox_preds = preds_dicts['all_bbox_preds']
|
| 410 |
+
enc_cls_scores = preds_dicts['enc_cls_scores']
|
| 411 |
+
enc_bbox_preds = preds_dicts['enc_bbox_preds']
|
| 412 |
+
|
| 413 |
+
num_dec_layers = len(all_cls_scores)
|
| 414 |
+
device = gt_labels_list[0].device
|
| 415 |
+
gt_bboxes_list = [torch.cat(
|
| 416 |
+
(gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
|
| 417 |
+
dim=1).to(device) for gt_bboxes in gt_bboxes_list]
|
| 418 |
+
|
| 419 |
+
all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
|
| 420 |
+
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
|
| 421 |
+
all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
|
| 422 |
+
|
| 423 |
+
losses_cls, losses_bbox = multi_apply(
|
| 424 |
+
self.loss_single, all_cls_scores, all_bbox_preds,
|
| 425 |
+
all_gt_bboxes_list, all_gt_labels_list,
|
| 426 |
+
all_gt_bboxes_ignore_list)
|
| 427 |
+
|
| 428 |
+
loss_dict = dict()
|
| 429 |
+
# loss of proposal generated from encode feature map
|
| 430 |
+
if enc_cls_scores is not None:
|
| 431 |
+
binary_labels_list = [
|
| 432 |
+
torch.zeros_like(gt_labels_list[i])
|
| 433 |
+
for i in range(len(all_gt_labels_list))
|
| 434 |
+
]
|
| 435 |
+
enc_loss_cls, enc_losses_bbox = \
|
| 436 |
+
self.loss_single(enc_cls_scores, enc_bbox_preds,
|
| 437 |
+
gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
|
| 438 |
+
loss_dict['enc_loss_cls'] = enc_loss_cls
|
| 439 |
+
loss_dict['enc_loss_bbox'] = enc_losses_bbox
|
| 440 |
+
|
| 441 |
+
if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:
|
| 442 |
+
loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)
|
| 443 |
+
|
| 444 |
+
# loss from the last decoder layer
|
| 445 |
+
loss_dict['loss_cls'] = losses_cls[-1]
|
| 446 |
+
loss_dict['loss_bbox'] = losses_bbox[-1]
|
| 447 |
+
|
| 448 |
+
# loss from other decoder layers
|
| 449 |
+
num_dec_layer = 0
|
| 450 |
+
for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):
|
| 451 |
+
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
|
| 452 |
+
loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
|
| 453 |
+
num_dec_layer += 1
|
| 454 |
+
return loss_dict
|
| 455 |
+
|
| 456 |
+
    @force_fp32(apply_to=('preds_dicts'))
    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
        """Decode network predictions into final 3D boxes, scores and labels.

        Args:
            preds_dicts (dict): raw head outputs, decoded by
                ``self.bbox_coder`` into a per-sample list of dicts with
                keys 'bboxes', 'scores', 'labels'.
            img_metas (list[dict]): unused here; kept for API compatibility.
            rescale (bool): unused here; kept for API compatibility.

        Returns:
            list[list]: per-sample [LiDARInstance3DBoxes, scores, labels].
        """
        preds_dicts = self.bbox_coder.decode(preds_dicts)
        num_samples = len(preds_dicts)
        ret_list = []
        for i in range(num_samples):
            preds = preds_dicts[i]
            bboxes = preds['bboxes']
            # shift z from center to box bottom, as LiDARInstance3DBoxes
            # expects; assumes dim 5 holds the box height — TODO confirm
            # against the bbox_coder's output layout
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
            bboxes = LiDARInstance3DBoxes(bboxes, 9)  # box_dim=9 (with velocity)
            scores = preds['scores']
            labels = preds['labels']
            ret_list.append([bboxes, scores, labels])
        return ret_list
|
models/sparsebev_sampling.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from .bbox.utils import decode_bbox
|
| 4 |
+
from .utils import rotation_3d_in_axis, DUMP
|
| 5 |
+
from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_sample_points(query_bbox, offset, pc_range):
    """Turn normalized per-query offsets into 3D sampling points.

    Each offset is scaled by the query box's size, rotated by the box yaw,
    and translated to the box center.

    Args:
        query_bbox: [B, Q, 10] encoded query boxes.
        offset: [B, Q, num_points, 4], normalized by stride.

    Returns:
        Tensor: sampling points, [B, Q, P, 3].
    """
    decoded = decode_bbox(query_bbox, pc_range)  # [B, Q, 9]

    center = decoded[..., 0:3]  # [B, Q, 3]
    size = decoded[..., 3:6]    # [B, Q, 3]
    yaw = decoded[..., 6:7]     # [B, Q, 1]

    # scale offsets to the box extent, rotate into the box frame, translate
    scaled = size[:, :, None, :] * offset[..., 0:3]   # [B, Q, P, 3]
    rotated = rotation_3d_in_axis(scaled, yaw)        # [B, Q, P, 3]

    return center[:, :, None, :] + rotated            # [B, Q, P, 3]
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
    """Project 3D sampling points into the cameras and gather image features.

    For every point, the point is projected into all N=6 camera views, one
    valid view is picked per point (argmax over the validity mask), and
    multi-scale features are sampled there via ``msmv_sampling``.

    Args:
        sample_points (Tensor): [B, Q, T, G, P, 3] 3D points
            (B batch, Q queries, T frames, G groups, P points).
        mlvl_feats (list[Tensor]): multi-level features, each reshaped by the
            caller to [B*T*G, N, H, W, C].
        scale_weights (Tensor): per-point softmax weights over feature levels.
        lidar2img (Tensor): [B, T*N, 4, 4] projection matrices.
        image_h, image_w (int): input image size used for normalization.
        eps (float): minimum depth to avoid division by ~0.

    Returns:
        Tensor: sampled features, [B, Q, G, T*P, C].
    """
    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 4]
    N = 6  # number of camera views (nuScenes rig)

    sample_points = sample_points.reshape(B, Q, T, G * P, 3)

    # get the projection matrix
    lidar2img = lidar2img[:, :, None, None, :, :]  # [B, TN, 1, 1, 4, 4]
    lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)
    lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)

    # expand the points to homogeneous coordinates, one copy per view
    ones = torch.ones_like(sample_points[..., :1])
    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, GP, 4]
    sample_points = sample_points[:, :, None, ..., None]  # [B, Q, T, GP, 4]
    sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
    sample_points = sample_points.transpose(1, 3)  # [B, T, N, Q, GP, 4, 1]

    # project 3d sampling points to image
    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]

    # homo coord -> pixel coord (clamp depth to eps so the division is safe)
    homo = sample_points_cam[..., 2:3]
    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]

    # normalize pixel coords to [0, 1]
    sample_points_cam[..., 0] /= image_w
    sample_points_cam[..., 1] /= image_h

    # check if out of image: a point is valid in a view when it is in front
    # of the camera (depth > eps) and inside the normalized image bounds
    valid_mask = ((homo > eps) \
        & (sample_points_cam[..., 1:2] > 0.0)
        & (sample_points_cam[..., 1:2] < 1.0)
        & (sample_points_cam[..., 0:1] > 0.0)
        & (sample_points_cam[..., 0:1] < 1.0)
    ).squeeze(-1).float()  # [B, T, N, Q, GP]

    if DUMP.enabled:
        # debug dumps for offline visualization
        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
        torch.save(valid_mask,
                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]

    # pick one view per point: argmax over the 0/1 validity mask returns the
    # first valid view (or view 0 when none is valid)
    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]

    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, Q, GP, 1, 2]
    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, Q, GP, 1]

    # encode the chosen view index as a third coordinate in [0, 1]
    # (divided by 5 = N - 1), as expected by the msmv_sampling kernel
    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)
    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)

    # reorder scale weights to match the [B*G*T, Q, P, L] kernel layout
    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)

    # multi-scale multi-view sampling (custom CUDA op, pytorch fallback exists)
    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
    C = final.shape[2]  # [BTG, Q, C, P]
    final = final.reshape(B, T, G, Q, C, P)
    final = final.permute(0, 3, 2, 1, 5, 4)
    final = final.flatten(3, 4)  # [B, Q, G, FP, C]

    return final
|
models/sparsebev_transformer.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from mmcv.runner import BaseModule
|
| 6 |
+
from mmcv.cnn import bias_init_with_prob
|
| 7 |
+
from mmcv.cnn.bricks.transformer import MultiheadAttention, FFN
|
| 8 |
+
from mmdet.models.utils.builder import TRANSFORMER
|
| 9 |
+
from .bbox.utils import decode_bbox
|
| 10 |
+
from .utils import inverse_sigmoid, DUMP
|
| 11 |
+
from .sparsebev_sampling import sampling_4d, make_sample_points
|
| 12 |
+
from .checkpoint import checkpoint as cp
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@TRANSFORMER.register_module()
class SparseBEVTransformer(BaseModule):
    """Thin wrapper around the SparseBEV decoder.

    Runs the decoder and sanitizes its outputs (NaN/Inf -> finite) before
    returning them to the head.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)

        self.pc_range = pc_range
        self.embed_dims = embed_dims
        self.decoder = SparseBEVTransformerDecoder(
            embed_dims, num_frames, num_points, num_layers,
            num_levels, num_classes, code_size, pc_range=pc_range)

    @torch.no_grad()
    def init_weights(self):
        # delegate custom initialization to the decoder stack
        self.decoder.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        raw_cls, raw_bbox = self.decoder(
            query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)

        # replace any NaN/Inf so downstream losses stay finite
        return torch.nan_to_num(raw_cls), torch.nan_to_num(raw_bbox)
|
| 40 |
+
class SparseBEVTransformerDecoder(BaseModule):
    """Iteratively refines query boxes with a single shared decoder layer.

    The same ``SparseBEVTransformerDecoderLayer`` is applied ``num_layers``
    times; after each application the (detached) predicted boxes become the
    query boxes for the next iteration.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoder, self).__init__(init_cfg)
        self.num_layers = num_layers
        self.pc_range = pc_range

        # NOTE: one layer instance reused across all iterations
        # (weight sharing), not a stack of independent layers
        self.decoder_layer = SparseBEVTransformerDecoderLayer(
            embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
        )

    @torch.no_grad()
    def init_weights(self):
        self.decoder_layer.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """Run ``num_layers`` refinement iterations.

        Returns:
            tuple[Tensor, Tensor]: cls_scores and bbox_preds stacked over
            layers (first dim = num_layers).
        """
        cls_scores, bbox_preds = [], []

        # per-frame time offsets relative to the current frame, averaged
        # over the 6 cameras; stashed in img_metas[0] as a side channel for
        # the sampling/refinement modules (mutates the caller's img_metas)
        timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
        timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
        time_diff = timestamps[:, :1, :] - timestamps
        time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
        time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
        img_metas[0]['time_diff'] = time_diff

        # projection matrices, likewise stashed in img_metas[0]
        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
        lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
        img_metas[0]['lidar2img'] = lidar2img

        # reshape features once, in place, to the layout sampling_4d expects;
        # N=6 cameras and G=4 groups are hard-coded here
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)
            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, N, H, W, C]
            mlvl_feats[lvl] = feat.contiguous()

        for i in range(self.num_layers):
            DUMP.stage_count = i  # tag debug dumps with the iteration index

            query_feat, cls_score, bbox_pred = self.decoder_layer(
                query_bbox, query_feat, mlvl_feats, attn_mask, img_metas
            )
            # next iteration refines the current prediction; detach so
            # gradients do not flow through the box chain
            query_bbox = bbox_pred.clone().detach()

            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)

        cls_scores = torch.stack(cls_scores)
        bbox_preds = torch.stack(bbox_preds)

        return cls_scores, bbox_preds
| 93 |
+
class SparseBEVTransformerDecoderLayer(BaseModule):
    """One SparseBEV refinement step.

    Pipeline: scale-adaptive self-attention -> 4D feature sampling ->
    adaptive mixing -> FFN, followed by classification and box-refinement
    heads.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10, code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)

        self.embed_dims = embed_dims
        self.num_classes = num_classes
        self.code_size = code_size
        self.pc_range = pc_range

        # embeds the (normalized) query center xyz into the feature space
        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )

        self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)
        self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4, num_points=num_points, num_levels=num_levels, pc_range=pc_range)
        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=4, out_points=128)
        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)

        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

        # classification head: [Linear-LN-ReLU] x num_cls_fcs + Linear
        cls_branch = []
        for _ in range(num_cls_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        self.cls_branch = nn.Sequential(*cls_branch)

        # regression head: [Linear-ReLU] x num_reg_fcs + Linear
        reg_branch = []
        for _ in range(num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU(inplace=True))
        reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
        self.reg_branch = nn.Sequential(*reg_branch)

    @torch.no_grad()
    def init_weights(self):
        self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()

        # focal-loss style bias init so initial class probs are ~0.01
        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.cls_branch[-1].bias, bias_init)

    def refine_bbox(self, bbox_proposal, bbox_delta):
        """Refine the box center in inverse-sigmoid space; all other box
        parameters are predicted afresh from bbox_delta."""
        xyz = inverse_sigmoid(bbox_proposal[..., 0:3])
        xyz_delta = bbox_delta[..., 0:3]
        xyz_new = torch.sigmoid(xyz_delta + xyz)

        return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """
        query_bbox: [B, Q, 10] [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]

        Returns:
            tuple: (query_feat, cls_score, bbox_pred).
        """
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos

        # attention -> sampling -> mixing -> FFN, each followed by LayerNorm
        query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))
        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))

        cls_score = self.cls_branch(query_feat)  # [B, Q, num_classes]
        bbox_pred = self.reg_branch(query_feat)  # [B, Q, code_size]
        bbox_pred = self.refine_bbox(query_bbox, bbox_pred)

        # convert the predicted displacement into a velocity by dividing by
        # the time gap to the previous frame (time_diff was stashed in
        # img_metas[0] by the decoder); in-place edit of bbox_pred
        time_diff = img_metas[0]['time_diff']  # [B, F]
        if time_diff.shape[1] > 1:
            time_diff = time_diff.clone()
            time_diff[time_diff < 1e-5] = 1.0  # avoid division by ~0
            bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]

        if DUMP.enabled:
            # debug dumps of decoded boxes/scores for this refinement stage
            query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
            bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)
            cls_score_sig = torch.sigmoid(cls_score)
            torch.save(query_bbox_dec, '{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(bbox_pred_dec, '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(cls_score_sig, '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        return query_feat, cls_score, bbox_pred
|
| 184 |
+
class SparseBEVSelfAttention(BaseModule):
    """Scale-adaptive self-attention between queries.

    Standard multi-head attention whose attention logits are biased by the
    negative pairwise distance between query box centers, scaled by a
    learned per-head, per-query temperature ``tau``.
    """

    def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
        super().__init__(init_cfg)
        self.pc_range = pc_range

        self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)
        # predicts one temperature per head from each query feature
        self.gen_tau = nn.Linear(embed_dims, num_heads)

    @torch.no_grad()
    def init_weights(self):
        # zero weight + uniform positive bias: tau starts input-independent
        nn.init.zeros_(self.gen_tau.weight)
        nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)

    def inner_forward(self, query_bbox, query_feat, pre_attn_mask):
        """
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]
        """
        dist = self.calc_bbox_dists(query_bbox)
        tau = self.gen_tau(query_feat)  # [B, Q, 8]

        if DUMP.enabled:
            torch.save(tau, '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        tau = tau.permute(0, 2, 1)  # [B, 8, Q]
        # additive attention bias: closer queries attend more strongly
        attn_mask = dist[:, None, :, :] * tau[..., None]  # [B, 8, Q, Q]
        if pre_attn_mask is not None:
            # hard-mask positions forbidden by the caller (e.g. denoising groups)
            attn_mask[:, :, pre_attn_mask] = float('-inf')
        attn_mask = attn_mask.flatten(0, 1)  # [Bx8, Q, Q]
        return self.attention(query_feat, attn_mask=attn_mask)

    def forward(self, query_bbox, query_feat, pre_attn_mask):
        # gradient checkpointing during training to save memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, pre_attn_mask)

    @torch.no_grad()
    def calc_bbox_dists(self, bboxes):
        """Negative pairwise BEV (xy) distances between query box centers,
        [B, Q, Q]; no gradients flow through the geometry."""
        centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]

        dist = []
        for b in range(centers.shape[0]):
            dist_b = torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)
            dist.append(dist_b[None, ...])

        dist = torch.cat(dist, dim=0)  # [B, Q, Q]
        dist = -dist  # larger value (less negative) = closer pair

        return dist
|
| 236 |
+
class SparseBEVSampling(BaseModule):
    """Predicts 3D sampling points per query and gathers image features.

    Offsets are predicted from the query feature, anchored to the query box,
    warped backwards in time by the query's predicted velocity, then sampled
    from multi-frame multi-view features via ``sampling_4d``.
    """

    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
        super().__init__(init_cfg)

        self.num_frames = num_frames
        self.num_points = num_points
        self.num_groups = num_groups
        self.num_levels = num_levels
        self.pc_range = pc_range

        # per-query offsets (G*P points, xyz) and per-level mixing weights
        self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)
        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)

    def init_weights(self):
        # zero weights + uniform bias: initial points spread uniformly in
        # [-0.5, 0.5] of the box extent, independent of the input feature
        bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)
        nn.init.zeros_(self.sampling_offset.weight)
        nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)

    def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        '''
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]

        Returns sampled features of shape [B, Q, G, FP, C].
        '''
        B, Q = query_bbox.shape[:2]
        image_h, image_w, _ = img_metas[0]['img_shape'][0]

        # sampling offset of all frames
        sampling_offset = self.sampling_offset(query_feat)
        sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)
        sampling_points = make_sample_points(query_bbox, sampling_offset, self.pc_range)  # [B, Q, GP, 3]
        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)

        # warp sample points based on velocity: shift xy backwards by
        # velocity * time gap so past frames are sampled where the object was
        # (time_diff is stashed in img_metas[0] by the decoder)
        time_diff = img_metas[0]['time_diff']  # [B, F]
        time_diff = time_diff[:, None, :, None]  # [B, 1, F, 1]
        vel = query_bbox[..., 8:].detach()  # [B, Q, 2]
        vel = vel[:, :, None, :]  # [B, Q, 1, 2]
        dist = vel * time_diff  # [B, Q, F, 2]
        dist = dist[:, :, :, None, None, :]  # [B, Q, F, 1, 1, 2]
        sampling_points = torch.cat([
            sampling_points[..., 0:2] - dist,
            sampling_points[..., 2:3]
        ], dim=-1)

        # scale weights: softmax over feature levels, shared across frames
        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
        scale_weights = torch.softmax(scale_weights, dim=-1)
        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)

        # sampling
        sampled_feats = sampling_4d(
            sampling_points,
            mlvl_feats,
            scale_weights,
            img_metas[0]['lidar2img'],
            image_h, image_w
        )  # [B, Q, G, FP, C]

        return sampled_feats

    def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        # gradient checkpointing during training to save memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)
|
| 304 |
+
class AdaptiveMixing(nn.Module):
    """Adaptive mixing of sampled features (AdaMixer-style).

    Per query, a linear layer generates dynamic channel-mixing and
    point-mixing matrices; the sampled features [B, Q, G, P, C] are mixed by
    both, flattened, projected back to the query dimension, and added to the
    query as a residual.
    """

    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixing, self).__init__()

        # unspecified dims default to their input counterparts
        out_dim = out_dim if out_dim is not None else in_dim
        out_points = out_points if out_points is not None else in_points
        query_dim = query_dim if query_dim is not None else in_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points

        # per-group effective channel widths
        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        # sizes of the dynamic channel-mix (M) and point-mix (S) matrices
        self.m_parameters = self.eff_in_dim * self.eff_out_dim
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.m_parameters + self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        # zero weight: dynamic matrices start input-independent (bias only)
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        """x: [B, Q, G, P, C] sampled features; query: [B, Q, query_dim]."""
        B, Q, G, P, C = x.shape
        assert G == self.n_groups
        assert P == self.in_points
        assert C == self.eff_in_dim

        # generate per-query dynamic mixing matrices
        dyn_params = self.parameter_generator(query).reshape(B * Q, G, -1)
        channel_mix, point_mix = dyn_params.split(
            [self.m_parameters, self.s_parameters], 2)
        channel_mix = channel_mix.reshape(B * Q, G, self.eff_in_dim, self.eff_out_dim)
        point_mix = point_mix.reshape(B * Q, G, self.out_points, self.in_points)

        feat = x.reshape(B * Q, G, P, C)

        # adaptive channel mixing
        feat = torch.matmul(feat, channel_mix)
        feat = self.act(F.layer_norm(feat, [feat.size(-2), feat.size(-1)]))

        # adaptive point mixing (matmul on the left transposes implicitly)
        feat = torch.matmul(point_mix, feat)
        feat = self.act(F.layer_norm(feat, [feat.size(-2), feat.size(-1)]))

        # project back to the query dimension, residual add
        feat = self.out_proj(feat.reshape(B, Q, -1))
        return query + feat

    def forward(self, x, query):
        # gradient checkpointing during training to save memory
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)
|
models/utils.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from numpy import random
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GridMask(nn.Module):
    """Randomly zero out a regular grid of stripes on a batch of images.

    With probability ``prob`` — and only while in training mode — one grid
    pattern is drawn per batch and multiplied into every image and channel.

    Args:
        ratio (float): controls the stripe width relative to the grid period.
        prob (float): probability of applying the mask to a given batch.
    """

    def __init__(self, ratio=0.5, prob=0.7):
        super(GridMask, self).__init__()
        self.ratio = ratio
        self.prob = prob

    def forward(self, x):
        # Identity at eval time, or when the random draw skips this batch.
        if np.random.rand() > self.prob or not self.training:
            return x

        n, c, h, w = x.size()
        flat = x.view(-1, h, w)

        # Build the mask on a 1.5x canvas so the center crop below can hold
        # an arbitrarily shifted grid.
        big_h = int(1.5 * h)
        big_w = int(1.5 * w)

        period = np.random.randint(2, h)  # grid period
        stripe = min(max(int(period * self.ratio + 0.5), 1), period - 1)
        canvas = np.ones((big_h, big_w), np.uint8)
        off_h = np.random.randint(period)
        off_w = np.random.randint(period)

        # Carve horizontal stripes of zeros...
        for k in range(big_h // period):
            lo = period * k + off_h
            hi = min(lo + stripe, big_h)
            canvas[lo:hi, :] = 0

        # ...and vertical stripes.
        for k in range(big_w // period):
            lo = period * k + off_w
            hi = min(lo + stripe, big_w)
            canvas[:, lo:hi] = 0

        # Center-crop back to (h, w), then invert so the carved stripes are
        # the surviving pixels.
        top = (big_h - h) // 2
        left = (big_w - w) // 2
        canvas = canvas[top:top + h, left:left + w]
        keep = 1 - torch.tensor(canvas, dtype=x.dtype, device=x.device)
        flat = flat * keep.expand_as(flat)

        return flat.view(n, c, h, w)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def rotation_3d_in_axis(points, angles):
    """Rotate groups of 3D points about the Z axis.

    Args:
        points: tensor of shape (..., P, 3).
        angles: tensor of shape (..., 1); the trailing singleton dimension is
            dropped before use.

    Returns:
        Rotated points with the same shape as the input.
    """
    assert points.shape[-1] == 3
    assert angles.shape[-1] == 1
    angles = angles[..., 0]

    n_points = points.shape[-2]
    input_dims = angles.shape

    # Flatten any leading batch dimensions so a single bmm suffices.
    if len(input_dims) > 1:
        points = points.reshape(-1, n_points, 3)
        angles = angles.reshape(-1)

    sin_a = torch.sin(angles)
    cos_a = torch.cos(angles)
    one = torch.ones_like(cos_a)
    zero = torch.zeros_like(cos_a)

    # Row-major yaw rotation matrix, one (3, 3) block per batch entry.
    rot = torch.stack([
        torch.stack([cos_a, sin_a, zero], dim=-1),
        torch.stack([-sin_a, cos_a, zero], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    ], dim=-2)

    points = torch.bmm(points, rot)

    # Restore the original leading batch dimensions.
    if len(input_dims) > 1:
        points = points.reshape(*input_dims, n_points, 3)

    return points
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable inverse of the sigmoid (logit) function.

    Args:
        x (Tensor): values to invert; clamped into [0, 1] first.
        eps (float): lower clamp for numerator and denominator, avoiding
            log(0) and division by zero. Defaults to 1e-5.

    Returns:
        Tensor: logit of ``x``, same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def pad_multiple(inputs, img_metas, size_divisor=32):
    """Zero-pad images so both spatial dims are multiples of ``size_divisor``.

    Padding is applied on the bottom and right; ``img_shape`` and
    ``pad_shape`` of every entry in ``img_metas`` are updated in place.

    Args:
        inputs: 4D image tensor whose last two dims are (H, W).
        img_metas: list of per-sample meta dicts; the view count is taken
            from the first sample's 'ori_shape'.
        size_divisor (int): required divisor for height and width.

    Returns:
        The padded tensor, or ``inputs`` unchanged when already aligned.
    """
    _, _, img_h, img_w = inputs.shape

    # Distance to the next multiple (0 when already aligned).
    pad_h = (-img_h) % size_divisor
    pad_w = (-img_w) % size_divisor

    padded_shape = (img_h + pad_h, img_w + pad_w, 3)
    num_views = len(img_metas[0]['ori_shape'])

    for meta in img_metas:
        meta['img_shape'] = [padded_shape] * num_views
        meta['pad_shape'] = [padded_shape] * num_views

    if pad_h == 0 and pad_w == 0:
        return inputs
    return F.pad(inputs, [0, pad_w, 0, pad_h], value=0)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    Adapted from kornia-style conversion with modified value ranges: the
    input is assumed to be in (0, 255) (it is divided by 255 internally),
    and in the returned tensor the H channel is in degrees [0, 360), S is
    in [0, 1] and V is in [0, 255].

    Args:
        image: RGB image to be converted to HSV with shape of
            :math:`(*, 3, H, W)`, values in (0, 255).
        eps: scalar to enforce numerical stability when dividing by the
            maximum channel value.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.
        H is in [0, 360), S in [0, 1] and V in [0, 255].

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hsv(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # work in [0, 1] internally
    image = image / 255.0

    max_rgb, argmax_rgb = image.max(-3)
    min_rgb, argmin_rgb = image.min(-3)
    deltac = max_rgb - min_rgb

    v = max_rgb
    s = deltac / (max_rgb + eps)

    # avoid division by zero for grey pixels (delta == 0); their hue is 0 anyway
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)

    h1 = bc - gc
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac

    # select the hue formula matching whichever channel is the maximum
    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    h = (h / 6.0) % 1.0

    # rescale: hue to degrees, value back to the input's 0..255 range
    h = h * 360.0
    v = v * 255.0

    return torch.stack((h, s, v), dim=-3)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    Inverse of ``rgb_to_hsv`` above: the H channel is assumed to be in
    degrees [0, 360], S in [0, 1] and V in [0, 255]; the returned RGB
    values are in [0, 255].

    Args:
        image: HSV image to be converted to RGB with shape of
            :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`,
        values in [0, 255].

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # normalize hue to [0, 1] and value to [0, 1]
    h: torch.Tensor = image[..., 0, :, :] / 360.0
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :] / 255.0

    # hi: hue sextant index in {0..5}; f: fractional position inside it
    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    hi = hi.long()
    # pick (r, g, b) for each pixel's sextant via a gathered 6-entry lookup
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)
    # back to the 0..255 range used elsewhere in this module
    out = out * 255.0

    return out
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class GpuPhotoMetricDistortion:
    """Apply photometric distortion to a batch of images sequentially.

    Every transformation is applied independently per image with a
    probability of 0.5. The position of random contrast is in second or
    second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue, in degrees.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, imgs):
        """Distort a batch of images.

        Args:
            imgs (Tensor): batch of BGR images of shape (N, 3, H, W);
                assumed float with values in [0, 255] — TODO confirm with
                the caller. Modified in place per-image and also returned.

        Returns:
            Tensor: the distorted batch (BGR), same shape as the input.
        """
        imgs = imgs[:, [2, 1, 0], :, :]  # BGR to RGB

        # Decide per-image, up front, whether contrast is applied before or
        # after the HSV round-trip, so both loops agree on the choice.
        contrast_modes = []
        for _ in range(imgs.shape[0]):
            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            contrast_modes.append(random.randint(2))

        for idx in range(imgs.shape[0]):
            # random brightness
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta, self.brightness_delta)
                imgs[idx] += delta

            if contrast_modes[idx] == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

        # convert color from BGR to HSV
        imgs = rgb_to_hsv(imgs)

        for idx in range(imgs.shape[0]):
            # random saturation
            if random.randint(2):
                imgs[idx, 1] *= random.uniform(self.saturation_lower, self.saturation_upper)

            # random hue
            if random.randint(2):
                imgs[idx, 0] += random.uniform(-self.hue_delta, self.hue_delta)

        # wrap hue back into [0, 360] (rgb_to_hsv returns degrees)
        imgs[:, 0][imgs[:, 0] > 360] -= 360
        imgs[:, 0][imgs[:, 0] < 0] += 360

        # convert color from HSV to BGR
        imgs = hsv_to_rgb(imgs)

        for idx in range(imgs.shape[0]):
            # random contrast
            if contrast_modes[idx] == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

            # randomly swap channels
            if random.randint(2):
                imgs[idx] = imgs[idx, random.permutation(3)]

        imgs = imgs[:, [2, 1, 0], :, :]  # RGB to BGR

        return imgs
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class DumpConfig:
    """Mutable, process-global settings controlling result dumping."""

    def __init__(self):
        # Off by default; callers flip this to enable dumping.
        self.enabled = False
        # Destination directory for dumped outputs.
        self.out_dir = 'outputs'
        # Counters maintained by callers (not modified in this module).
        self.stage_count = 0
        self.frame_count = 0


# Module-level instance intended to be imported and shared elsewhere.
DUMP = DumpConfig()
|
timing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import utils
|
| 3 |
+
import logging
|
| 4 |
+
import argparse
|
| 5 |
+
import importlib
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed
|
| 8 |
+
import torch.backends.cudnn as cudnn
|
| 9 |
+
from mmcv import Config, DictAction
|
| 10 |
+
from mmcv.parallel import MMDataParallel
|
| 11 |
+
from mmcv.runner import load_checkpoint
|
| 12 |
+
from mmdet.apis import set_random_seed
|
| 13 |
+
from mmdet3d.datasets import build_dataset, build_dataloader
|
| 14 |
+
from mmdet3d.models import build_model
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main():
    """Benchmark a detector's pure inference speed on the validation set."""
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    # BUGFIX: type=int is required on the numeric flags — without it any
    # command-line override arrives as a string, and the comparisons below
    # ('i >= args.num_warmup', '% args.log_interval') raise TypeError.
    parser.add_argument('--num_warmup', type=int, default=10)
    parser.add_argument('--samples', type=int, default=500)
    parser.add_argument('--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument('--override', nargs='+', action=DictAction)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
    utils.init_logging(None, cfgs.debug)

    # you need GPUs (exactly one, so timings are not skewed by parallelism)
    assert torch.cuda.is_available() and torch.cuda.device_count() == 1
    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))
    torch.cuda.set_device(0)

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)
    cudnn.benchmark = True

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()

    assert torch.cuda.device_count() == 1
    model = MMDataParallel(model, [0])

    logging.info('Loading checkpoint from %s' % args.weights)
    load_checkpoint(
        model, args.weights, map_location='cuda', strict=False,
        logger=logging.Logger(__name__, logging.ERROR)
    )
    model.eval()

    # Accumulate only post-warmup time; cuda.synchronize brackets each
    # forward pass so asynchronous kernel launches are fully measured.
    pure_inf_time = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            torch.cuda.synchronize()
            start_time = time.perf_counter()

            model(return_loss=False, rescale=True, **data)

            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time

            if i >= args.num_warmup:
                pure_inf_time += elapsed
                if (i + 1) % args.log_interval == 0:
                    fps = (i + 1 - args.num_warmup) / pure_inf_time
                    print(f'Done sample [{i + 1:<3}/ {args.samples}], '
                          f'fps: {fps:.1f} sample / s')

            if (i + 1) == args.samples:
                break


if __name__ == '__main__':
    main()
|
train.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import utils
|
| 3 |
+
import shutil
|
| 4 |
+
import logging
|
| 5 |
+
import argparse
|
| 6 |
+
import importlib
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from mmcv import Config, DictAction
|
| 11 |
+
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
| 12 |
+
from mmcv.runner import EpochBasedRunner, build_optimizer, load_checkpoint
|
| 13 |
+
from mmdet.apis import set_random_seed
|
| 14 |
+
from mmdet.core import DistEvalHook, EvalHook
|
| 15 |
+
from mmdet3d.datasets import build_dataset
|
| 16 |
+
from mmdet3d.models import build_model
|
| 17 |
+
from loaders.builder import build_dataloader
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def main():
    """Entry point: set up distributed training and run the training loop.

    Reads the config file, prepares the work directory (rank 0 only), builds
    the datasets/model/optimizer, wires up mmcv's EpochBasedRunner hooks, and
    finally resumes or starts training.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--override', nargs='+', action=DictAction)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmdet3d'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size
    # (env vars take precedence; CLI args are a fallback for manual launches)
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)

    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if local_rank == 0:
        # resume or start a new run
        if cfgs.resume_from is not None:
            assert os.path.isfile(cfgs.resume_from)
            work_dir = os.path.dirname(cfgs.resume_from)
        else:
            run_name = ''
            if not cfgs.debug:
                run_name = input('Name your run (leave blank for default): ')
            if run_name == '':
                run_name = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")

            work_dir = os.path.join('outputs', cfgs.model.type, run_name)
            if os.path.exists(work_dir):  # must be an empty dir
                if input('Path "%s" already exists, overwrite it? [Y/n] ' % work_dir) == 'n':
                    print('Bye.')
                    exit(0)
                shutil.rmtree(work_dir)

            os.makedirs(work_dir, exist_ok=False)

        # init logging, backup code
        utils.init_logging(os.path.join(work_dir, 'train.log'), cfgs.debug)
        utils.backup_code(work_dir)
        logging.info('Logs will be saved to %s' % work_dir)

    else:
        # disable logging on other workers
        logging.root.disabled = True
        work_dir = '/tmp'

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading training set from %s' % cfgs.dataset_root)
    train_dataset = build_dataset(cfgs.data.train)
    train_loader = build_dataloader(
        train_dataset,
        samples_per_gpu=cfgs.batch_size // world_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=True,
        seed=0,
    )

    logging.info('Loading validation set from %s' % cfgs.dataset_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.init_weights()
    model.cuda()
    model.train()

    n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    logging.info('Trainable parameters: %d (%.1fM)' % (n_params, n_params / 1e6))
    logging.info('Batch size per GPU: %d' % (cfgs.batch_size // world_size))

    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    logging.info('Creating optimizer: %s' % cfgs.optimizer.type)
    optimizer = build_optimizer(model, cfgs.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=work_dir,
        logger=logging.root,
        max_epochs=cfgs.total_epochs,
        meta=dict(),
    )

    # standard mmcv hooks: LR schedule, grad step, checkpointing, logging
    runner.register_lr_hook(cfgs.lr_config)
    runner.register_optimizer_hook(cfgs.optimizer_config)
    runner.register_checkpoint_hook(cfgs.checkpoint_config)
    runner.register_logger_hooks(cfgs.log_config)
    runner.register_timer_hook(dict(type='IterTimerHook'))
    runner.register_custom_hooks(dict(type='DistSamplerSeedHook'))

    # periodic evaluation (interval <= 0 disables it)
    if cfgs.eval_config['interval'] > 0:
        if world_size > 1:
            runner.register_hook(DistEvalHook(val_loader, interval=cfgs.eval_config['interval'], gpu_collect=True))
        else:
            runner.register_hook(EvalHook(val_loader, interval=cfgs.eval_config['interval']))

    if cfgs.resume_from is not None:
        # resume restores optimizer/epoch state as well as weights
        logging.info('Resuming from %s' % cfgs.resume_from)
        runner.resume(cfgs.resume_from)

    elif cfgs.load_from is not None:
        # load_from only initializes weights (optionally remapping key names)
        logging.info('Loading checkpoint from %s' % cfgs.load_from)
        if cfgs.revise_keys is not None:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
                revise_keys=cfgs.revise_keys
            )
        else:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
            )

    runner.run([train_loader], [('train', 1)])


if __name__ == '__main__':
    main()
|
utils.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import glob
|
| 4 |
+
import torch
|
| 5 |
+
import shutil
|
| 6 |
+
import logging
|
| 7 |
+
import datetime
|
| 8 |
+
from mmcv.runner.hooks import HOOKS
|
| 9 |
+
from mmcv.runner.hooks.logger import LoggerHook, TextLoggerHook
|
| 10 |
+
from mmcv.runner.dist_utils import master_only
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def init_logging(filename=None, debug=False):
    """Replace the root logger and attach stdout (and optional file) handlers.

    Args:
        filename (str | None): when given, log records are also appended to
            this file.
        debug (bool): use DEBUG level instead of INFO.
    """
    # A fresh root logger discards any handlers installed by third parties.
    logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
    fmt = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')

    handlers = [logging.StreamHandler(sys.stdout)]
    if filename is not None:
        handlers.append(logging.FileHandler(filename))

    for handler in handlers:
        handler.setFormatter(fmt)
        logging.root.addHandler(handler)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def backup_code(work_dir, verbose=False):
    """Copy the project's source files into ``work_dir``/backup.

    Args:
        work_dir (str): experiment output directory; files are copied to
            ``work_dir``/backup/<relative path>.
        verbose (bool): log every copied file when True.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    for pattern in ['*.py', 'configs/*.py', 'models/*.py', 'loaders/*.py', 'loaders/pipelines/*.py']:
        # BUGFIX: resolve the glob against this file's directory instead of
        # the current working directory — previously the files *found* came
        # from the CWD while the files *copied* came from base_dir, so the
        # backup was wrong (or empty) when launched from elsewhere.
        for src in glob.glob(os.path.join(base_dir, pattern)):
            rel = os.path.relpath(src, base_dir)
            dst = os.path.join(work_dir, 'backup', os.path.dirname(rel))

            if verbose:
                logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))

            os.makedirs(dst, exist_ok=True)
            shutil.copy2(src, dst)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@HOOKS.register_module()
|
| 43 |
+
class MyTextLoggerHook(TextLoggerHook):
|
| 44 |
+
def _log_info(self, log_dict, runner):
|
| 45 |
+
# print exp name for users to distinguish experiments
|
| 46 |
+
# at every ``interval_exp_name`` iterations and the end of each epoch
|
| 47 |
+
if runner.meta is not None and 'exp_name' in runner.meta:
|
| 48 |
+
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
| 49 |
+
self.by_epoch and self.end_of_epoch(runner)):
|
| 50 |
+
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
| 51 |
+
runner.logger.info(exp_info)
|
| 52 |
+
|
| 53 |
+
# by epoch: Epoch [4][100/1000]
|
| 54 |
+
# by iter: Iter [100/100000]
|
| 55 |
+
if self.by_epoch:
|
| 56 |
+
log_str = f'Epoch [{log_dict["epoch"]}/{runner.max_epochs}]' \
|
| 57 |
+
f'[{log_dict["iter"]}/{len(runner.data_loader)}] '
|
| 58 |
+
else:
|
| 59 |
+
log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}] '
|
| 60 |
+
|
| 61 |
+
log_str += 'loss: %.2f, ' % log_dict['loss']
|
| 62 |
+
|
| 63 |
+
if 'time' in log_dict.keys():
|
| 64 |
+
# MOD: skip the first iteration since it's not accurate
|
| 65 |
+
if runner.iter == self.start_iter:
|
| 66 |
+
time_sec_avg = log_dict['time']
|
| 67 |
+
else:
|
| 68 |
+
self.time_sec_tot += (log_dict['time'] * self.interval)
|
| 69 |
+
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter)
|
| 70 |
+
|
| 71 |
+
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
| 72 |
+
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
| 73 |
+
log_str += f'eta: {eta_str}, '
|
| 74 |
+
log_str += f'time: {log_dict["time"]:.2f}s, ' \
|
| 75 |
+
f'data: {log_dict["data_time"] * 1000:.0f}ms, '
|
| 76 |
+
# statistic memory
|
| 77 |
+
if torch.cuda.is_available():
|
| 78 |
+
log_str += f'mem: {log_dict["memory"]}M'
|
| 79 |
+
|
| 80 |
+
runner.logger.info(log_str)
|
| 81 |
+
|
| 82 |
+
def log(self, runner):
|
| 83 |
+
if 'eval_iter_num' in runner.log_buffer.output:
|
| 84 |
+
# this doesn't modify runner.iter and is regardless of by_epoch
|
| 85 |
+
cur_iter = runner.log_buffer.output.pop('eval_iter_num')
|
| 86 |
+
else:
|
| 87 |
+
cur_iter = self.get_iter(runner, inner_iter=True)
|
| 88 |
+
|
| 89 |
+
log_dict = {
|
| 90 |
+
'mode': self.get_mode(runner),
|
| 91 |
+
'epoch': self.get_epoch(runner),
|
| 92 |
+
'iter': cur_iter
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
# only record lr of the first param group
|
| 96 |
+
cur_lr = runner.current_lr()
|
| 97 |
+
if isinstance(cur_lr, list):
|
| 98 |
+
log_dict['lr'] = cur_lr[0]
|
| 99 |
+
else:
|
| 100 |
+
assert isinstance(cur_lr, dict)
|
| 101 |
+
log_dict['lr'] = {}
|
| 102 |
+
for k, lr_ in cur_lr.items():
|
| 103 |
+
assert isinstance(lr_, list)
|
| 104 |
+
log_dict['lr'].update({k: lr_[0]})
|
| 105 |
+
|
| 106 |
+
if 'time' in runner.log_buffer.output:
|
| 107 |
+
# statistic memory
|
| 108 |
+
if torch.cuda.is_available():
|
| 109 |
+
log_dict['memory'] = self._get_max_memory(runner)
|
| 110 |
+
|
| 111 |
+
log_dict = dict(log_dict, **runner.log_buffer.output)
|
| 112 |
+
|
| 113 |
+
# MOD: disable writing to files
|
| 114 |
+
# self._dump_log(log_dict, runner)
|
| 115 |
+
self._log_info(log_dict, runner)
|
| 116 |
+
|
| 117 |
+
return log_dict
|
| 118 |
+
|
| 119 |
+
def after_train_epoch(self, runner):
    """After a training epoch, print the nuScenes evaluation summary.

    Only runs when evaluation results have been flushed into the log
    buffer (``runner.log_buffer.ready``).
    """
    if not runner.log_buffer.ready:
        return

    metrics = self.get_loggable_tags(runner)
    runner.logger.info('--- Evaluation Results ---')
    # Emit the seven headline nuScenes detection metrics, in the
    # conventional order.
    for name in ('mAP', 'mATE', 'mASE', 'mAOE', 'mAVE', 'mAAE', 'NDS'):
        runner.logger.info('%s: %.4f' % (name, metrics['val/pts_bbox_NuScenes/' + name]))
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@HOOKS.register_module()
class MyTensorboardLoggerHook(LoggerHook):
    """TensorBoard logger hook with custom tag filtering.

    Compared to mmcv's stock TensorBoard hook, this one:
      * merges ``learning_rate`` into the ``train/`` group,
      * drops momentum and intermediate decoder losses (``train/d0.loss*`` ...),
      * keeps only the seven headline nuScenes metrics on the val side,
      * logs train scalars against the iteration and val scalars against
        the epoch.
    """

    def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True):
        # log_dir: TensorBoard output directory; falls back to
        # runner.work_dir in before_run() when left as None.
        super(MyTensorboardLoggerHook, self).__init__(
            interval, ignore_last, reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        """Create the SummaryWriter (rank 0 only)."""
        super(MyTensorboardLoggerHook, self).before_run(runner)
        if self.log_dir is None:
            self.log_dir = runner.work_dir
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        """Filter loggable tags and write them to TensorBoard."""
        tags = self.get_loggable_tags(runner)

        for key, value in tags.items():
            # MOD: merge into the 'train' group
            if key == 'learning_rate':
                key = 'train/learning_rate'

            # MOD: skip momentum
            ignore = False
            if key == 'momentum':
                ignore = True

            # MOD: skip intermediate losses
            # 'train/d%d.loss' % i is exactly 13 characters for i in 0..4,
            # so the prefix comparison below is well-formed.
            for i in range(5):
                if key[:13] == 'train/d%d.loss' % i:
                    ignore = True

            if key[:3] == 'val':
                # key[22:] strips the 'val/pts_bbox_NuScenes/' prefix
                # (22 characters) to recover the bare metric name.
                metric_name = key[22:]
                if metric_name in ['mAP', 'mATE', 'mASE', 'mAOE', 'mAVE', 'mAAE', 'NDS']:
                    key = 'val/' + metric_name
                else:
                    ignore = True

            # Only log train tags in train mode and val tags otherwise.
            if self.get_mode(runner) == 'train' and key[:5] != 'train':
                ignore = True

            if self.get_mode(runner) != 'train' and key[:3] != 'val':
                ignore = True

            if ignore:
                continue

            # Train scalars are indexed by iteration, val scalars by epoch.
            if key[:5] == 'train':
                self.writer.add_scalar(key, value, self.get_iter(runner))
            elif key[:3] == 'val':
                self.writer.add_scalar(key, value, self.get_epoch(runner))

    @master_only
    def after_run(self, runner):
        """Flush and close the TensorBoard writer."""
        self.writer.close()
|
val.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import utils
|
| 3 |
+
import logging
|
| 4 |
+
import argparse
|
| 5 |
+
import importlib
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.backends.cudnn as cudnn
|
| 10 |
+
from mmcv import Config
|
| 11 |
+
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
| 12 |
+
from mmcv.runner import load_checkpoint
|
| 13 |
+
from mmdet.apis import set_random_seed, multi_gpu_test, single_gpu_test
|
| 14 |
+
from mmdet3d.datasets import build_dataset, build_dataloader
|
| 15 |
+
from mmdet3d.models import build_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def evaluate(dataset, results, epoch):
    """Run nuScenes evaluation on detection results and log a summary.

    Args:
        dataset: dataset object exposing ``evaluate(results, jsonfile_prefix)``
            and returning a metrics dict keyed by 'pts_bbox_NuScenes/<name>'.
        results: per-sample detection results to score.
        epoch (int): epoch number, used only for the log header.

    Returns:
        dict: the seven headline metrics keyed by their bare names
        (mAP, mATE, mASE, mAOE, mAVE, mAAE, NDS).
    """
    metrics = dataset.evaluate(results, jsonfile_prefix=None)

    metric_names = ['mAP', 'mATE', 'mASE', 'mAOE', 'mAVE', 'mAAE', 'NDS']
    summary = {name: metrics['pts_bbox_NuScenes/' + name] for name in metric_names}

    logging.info('--- Evaluation Results (Epoch %d) ---' % epoch)
    for name in metric_names:
        logging.info('%s: %.4f' % (name, summary[name]))

    return summary
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main():
    """Entry point: validate a detector on the nuScenes validation set.

    Parses CLI args, sets up (optionally distributed) CUDA execution,
    builds the dataset/model from the mmcv config, loads the checkpoint,
    runs inference, and evaluates on rank 0.
    """
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)

    # register custom module
    # (importing these packages triggers their mmcv registry decorators)
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    # Pre-register silenced loggers so mmcv does not attach its own
    # chatty handlers to 'root'/'mmcv'.
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size
    # CLI values are fallbacks; env vars (e.g. set by torchrun) win.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)

    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Only rank 0 logs; other ranks disable their root logger entirely.
    if local_rank == 0:
        utils.init_logging(None, cfgs.debug)
    else:
        logging.root.disabled = True

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    # NOTE: set_device is called before init_process_group on purpose so
    # NCCL binds each process to its own GPU.
    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)
    cudnn.benchmark = True

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=args.batch_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False,  # validation order must be stable for evaluation
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()

    # Wrap the model even in the single-GPU case: mmcv test APIs expect a
    # (MM)DataParallel-wrapped model.
    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    # NOTE(review): if args.weights is not an existing file, the script
    # silently evaluates randomly-initialized weights — consider logging
    # a warning or failing fast here.
    if os.path.isfile(args.weights):
        logging.info('Loading checkpoint from %s' % args.weights)
        load_checkpoint(
            model, args.weights, map_location='cuda', strict=True,
            logger=logging.Logger(__name__, logging.ERROR)
        )

    if world_size > 1:
        results = multi_gpu_test(model, val_loader, gpu_collect=True)
    else:
        results = single_gpu_test(model, val_loader)

    # Only rank 0 holds the gathered results and runs the evaluation;
    # epoch is -1 because this is a standalone validation run.
    if local_rank == 0:
        evaluate(val_dataset, results, -1)


if __name__ == '__main__':
    main()
|