Haisong Liu committed on
Release model: vit_eva02_1600x640_trainval_future (#46)
Browse files- README.md +1 -0
- configs/r50_nuimg_704x256.py +1 -1
- configs/vit_eva02_1600x640_trainval_future.py +112 -0
- loaders/pipelines/loading.py +135 -0
- models/backbones/__init__.py +2 -1
- models/backbones/eva02/__init__.py +1 -0
- models/backbones/eva02/backbone.py +94 -0
- models/backbones/eva02/batch_norm.py +214 -0
- models/backbones/eva02/blocks.py +111 -0
- models/backbones/eva02/drop.py +37 -0
- models/backbones/eva02/fpn.py +50 -0
- models/backbones/eva02/main.py +93 -0
- models/backbones/eva02/utils.py +361 -0
- models/backbones/eva02/vit.py +609 -0
- models/backbones/eva02/wrappers.py +148 -0
- models/sparsebev.py +4 -4
- val.py +1 -0
README.md
CHANGED
|
@@ -28,6 +28,7 @@ This is the official PyTorch implementation for our ICCV 2023 paper:
|
|
| 28 |
| [r50_nuimg_704x256_400q_36ep](configs/r50_nuimg_704x256_400q_36ep.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 28h (8x2080Ti) | 55.8 | - | 23.5 | [gdrive](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view) |
|
| 29 |
| [r101_nuimg_1408x512](configs/r101_nuimg_1408x512.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r101_fpn_1x_nuim/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth) | 2d8h (8xV100) | 59.2 | - | 6.5 | [gdrive](https://drive.google.com/file/d/1dKu5cR1fuo-O0ynyBh-RCPtHrgut29mN/view) |
|
| 30 |
| [vov99_dd3d_1600x640_trainval_future](configs/vov99_dd3d_1600x640_trainval_future.py) | [DD3D](https://drive.google.com/file/d/1gQkhWERCzAosBwG5bh2BKkt1k0TJZt-A/view) | 4d1h (8xA100) | 84.9 | 67.5 | - | [gdrive](https://drive.google.com/file/d/1TL0QoCiWD5uq8PCAWWE3A-g73ibK1R0S/view) |
|
|
|
|
| 31 |
|
| 32 |
* We use `r50_nuimg_704x256` for ablation studies and `r50_nuimg_704x256_400q_36ep` for comparison with others.
|
| 33 |
* We recommend using `r50_nuimg_704x256` to validate new ideas since it trains faster and the result is more stable.
|
|
|
|
| 28 |
| [r50_nuimg_704x256_400q_36ep](configs/r50_nuimg_704x256_400q_36ep.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 28h (8x2080Ti) | 55.8 | - | 23.5 | [gdrive](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view) |
|
| 29 |
| [r101_nuimg_1408x512](configs/r101_nuimg_1408x512.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r101_fpn_1x_nuim/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth) | 2d8h (8xV100) | 59.2 | - | 6.5 | [gdrive](https://drive.google.com/file/d/1dKu5cR1fuo-O0ynyBh-RCPtHrgut29mN/view) |
|
| 30 |
| [vov99_dd3d_1600x640_trainval_future](configs/vov99_dd3d_1600x640_trainval_future.py) | [DD3D](https://drive.google.com/file/d/1gQkhWERCzAosBwG5bh2BKkt1k0TJZt-A/view) | 4d1h (8xA100) | 84.9 | 67.5 | - | [gdrive](https://drive.google.com/file/d/1TL0QoCiWD5uq8PCAWWE3A-g73ibK1R0S/view) |
|
| 31 |
+
| [vit_eva02_1600x640_trainval_future](configs/vit_eva02_1600x640_trainval_future.py) | [EVA02](https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_seg_sys_o365.pth) | 11d (8xA100) | 85.3 | 70.2 | - | [gdrive](https://drive.google.com/file/d/1cx7h6PUqiaVWPixpcuB9AhsX3Sx4n0q_/view) |
|
| 32 |
|
| 33 |
* We use `r50_nuimg_704x256` for ablation studies and `r50_nuimg_704x256_400q_36ep` for comparison with others.
|
| 34 |
* We recommend using `r50_nuimg_704x256` to validate new ideas since it trains faster and the result is more stable.
|
configs/r50_nuimg_704x256.py
CHANGED
|
@@ -54,7 +54,7 @@ model = dict(
|
|
| 54 |
img_color_aug=True, # Move some augmentations to GPU
|
| 55 |
img_norm_cfg=img_norm_cfg,
|
| 56 |
img_pad_cfg=dict(size_divisor=32)),
|
| 57 |
-
stop_prev_grad=
|
| 58 |
img_backbone=img_backbone,
|
| 59 |
img_neck=img_neck,
|
| 60 |
pts_bbox_head=dict(
|
|
|
|
| 54 |
img_color_aug=True, # Move some augmentations to GPU
|
| 55 |
img_norm_cfg=img_norm_cfg,
|
| 56 |
img_pad_cfg=dict(size_divisor=32)),
|
| 57 |
+
stop_prev_grad=0,
|
| 58 |
img_backbone=img_backbone,
|
| 59 |
img_neck=img_neck,
|
| 60 |
pts_bbox_head=dict(
|
configs/vit_eva02_1600x640_trainval_future.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_base_ = ['./r50_nuimg_704x256.py']

# For nuScenes we usually do 10-class detection
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

# EVA-02 ViT-L backbone replacing the ResNet from the base config
# (_delete_=True discards the inherited backbone dict entirely).
img_backbone = dict(
    _delete_=True,
    type='EVA02',
    img_size=1536,
    real_img_size=(640, 1600),
    patch_size=16,
    in_chans=3,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    mlp_ratio=4*2/3,
    qkv_bias=True,
    drop_path_rate=0.3,
    use_abs_pos=True,
    window_size=16,
    # Blocks NOT listed here use global attention; the listed ones use
    # windowed attention (every 3rd block is global).
    window_block_indexes=(
        list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23))
    ),
    residual_block_indexes=(),
    use_act_checkpoint=True,
    # args for simple FPN
    fpn_out_channels=256,
    fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
    fpn_top_block=True,
    fpn_norm="LN",
    fpn_square_pad=1600,
    pretrained='pretrain/eva02_L_coco_seg_sys_o365.pth',
    frozen_blocks=3,
)
# ImageNet mean/std; to_rgb converts from BGR (mmcv default) to RGB.
img_norm_cfg = dict(
    mean=[123.675, 116.280, 103.530],
    std=[58.395, 57.120, 57.375],
    to_rgb=True
)

model = dict(
    img_backbone=img_backbone,
    img_neck=None,  # the backbone's built-in simple FPN replaces the neck
    stop_prev_grad=4,
    pts_bbox_head=dict(
        num_query=1600,
        transformer=dict(
            num_levels=5,
            num_points=8,
            num_frames=15))
)

# Image-domain augmentation config used by RandomTransformImage.
ida_aug_conf = {
    'resize_lim': (0.94, 1.25),
    'final_dim': (640, 1600),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    # Interleaved past/future sweeps (see loaders/pipelines/loading.py).
    dict(type='LoadMultiViewImageFromMultiSweepsFutureInterleave', prev_sweeps_num=7, next_sweeps_num=7),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
        'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
]

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweepsFutureInterleave', prev_sweeps_num=7, next_sweeps_num=7, test_mode=True),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1600, 900),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
            dict(type='Collect3D', keys=['img'], meta_keys=(
                'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
                'lidar2img', 'img_timestamp'))
        ])
]

# "trainval" setting: train on train+val splits together.
data = dict(
    train=dict(
        ann_file=['data/nuscenes/nuscenes_infos_train_sweep.pkl',
                  'data/nuscenes/nuscenes_infos_val_sweep.pkl'],
        pipeline=train_pipeline),
    val=dict(
        ann_file='data/nuscenes/nuscenes_infos_val_sweep.pkl',  # use nuscenes_infos_test_sweep.pkl for submission
        pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline)
)

load_from = None
revise_keys = None
|
loaders/pipelines/loading.py
CHANGED
|
@@ -255,3 +255,138 @@ class LoadMultiViewImageFromMultiSweepsFuture(object):
|
|
| 255 |
))
|
| 256 |
|
| 257 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
))
|
| 256 |
|
| 257 |
return results
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
'''
This func loads previous and future frames in interleaved order,
e.g. curr, prev1, next1, prev2, next2, prev3, next3...
'''
@PIPELINES.register_module()
class LoadMultiViewImageFromMultiSweepsFutureInterleave(object):
    """Append past and future camera sweeps to ``results`` in interleaved
    temporal order (6 cams of prev_k, then 6 cams of next_k, for each k).

    Fixes vs. the previous revision:
    - removed a leftover debug ``print`` that spammed stdout in the data
      pipeline whenever a sample had no future sweeps;
    - the duplicated prev/next loading logic is factored into a single
      helper (``_collect_sweeps``), so both directions stay in sync.

    Args:
        prev_sweeps_num (int): number of past frames to load.
        next_sweeps_num (int): number of future frames to load; must equal
            ``prev_sweeps_num`` so the interleaving is well defined.
        color_type (str): color flag passed to ``mmcv.imread``.
        test_mode (bool): if True, use the fixed test-time sampling interval
            instead of a random one.
    """

    def __init__(self,
                 prev_sweeps_num=5,
                 next_sweeps_num=5,
                 color_type='color',
                 test_mode=False):
        self.prev_sweeps_num = prev_sweeps_num
        self.next_sweeps_num = next_sweeps_num
        self.color_type = color_type
        self.test_mode = test_mode

        # Interleaving assumes the same number of past and future frames.
        assert prev_sweeps_num == next_sweeps_num

        # Sweep-sampling stride: randomized in [4, 8] during training,
        # fixed at 6 during testing.
        self.train_interval = [4, 8]
        self.test_interval = 6

        try:
            mmcv.use_backend('turbojpeg')
        except ImportError:
            mmcv.use_backend('cv2')

    def _collect_sweeps(self, results, direction, sweeps_num, interval, cam_types):
        """Gather ``sweeps_num`` frames from ``results['sweeps'][direction]``.

        When no sweeps exist in that direction, the current frame is
        duplicated so the output length stays constant.
        Returns a dict with 'img', 'img_timestamp', 'filename', 'lidar2img'
        lists of length ``sweeps_num * len(cam_types)``.
        """
        collected = dict(
            img=[],
            img_timestamp=[],
            filename=[],
            lidar2img=[]
        )
        sweeps = results['sweeps'][direction]

        if len(sweeps) == 0:
            # No sweeps available: repeat the current frame.
            for _ in range(sweeps_num):
                for j in range(len(cam_types)):
                    collected['img'].append(results['img'][j])
                    collected['img_timestamp'].append(results['img_timestamp'][j])
                    collected['filename'].append(results['filename'][j])
                    collected['lidar2img'].append(np.copy(results['lidar2img'][j]))
        else:
            # Sample every `interval`-th sweep, clamped to the last one.
            choices = [(k + 1) * interval - 1 for k in range(sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(sweeps) - 1)
                sweep = sweeps[sweep_idx]

                # Incomplete sweep (some cameras missing): fall back to the
                # neighboring sweep.
                if len(sweep.keys()) < len(cam_types):
                    sweep = sweeps[sweep_idx - 1]

                for sensor in cam_types:
                    collected['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
                    collected['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    collected['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    collected['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))

        return collected

    def __call__(self, results):
        if self.prev_sweeps_num == 0 and self.next_sweeps_num == 0:
            return results

        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if self.test_mode:
            interval = self.test_interval
        else:
            interval = np.random.randint(self.train_interval[0], self.train_interval[1] + 1)

        results_prev = self._collect_sweeps(
            results, 'prev', self.prev_sweeps_num, interval, cam_types)
        results_next = self._collect_sweeps(
            results, 'next', self.next_sweeps_num, interval, cam_types)

        assert len(results_prev['img']) % 6 == 0
        assert len(results_next['img']) % 6 == 0

        # Interleave: for each time step, append the 6 prev-frame cameras,
        # then the 6 next-frame cameras.
        for i in range(len(results_prev['img']) // 6):
            for j in range(6):
                results['img'].append(results_prev['img'][i * 6 + j])
                results['img_timestamp'].append(results_prev['img_timestamp'][i * 6 + j])
                results['filename'].append(results_prev['filename'][i * 6 + j])
                results['lidar2img'].append(results_prev['lidar2img'][i * 6 + j])

            for j in range(6):
                results['img'].append(results_next['img'][i * 6 + j])
                results['img_timestamp'].append(results_next['img_timestamp'][i * 6 + j])
                results['filename'].append(results_next['filename'][i * 6 + j])
                results['lidar2img'].append(results_next['lidar2img'][i * 6 + j])

        return results
|
models/backbones/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
from .vovnet import VoVNet
|
|
|
|
| 2 |
|
| 3 |
-
__all__ = ['VoVNet']
|
|
|
|
| 1 |
from .vovnet import VoVNet
|
| 2 |
+
from .eva02 import EVA02
|
| 3 |
|
| 4 |
+
__all__ = ['VoVNet', 'EVA02']
|
models/backbones/eva02/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .main import EVA02
|
models/backbones/eva02/backbone.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/shape_spec.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ShapeSpec:
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among pytorch modules.
    """

    # All fields default to None, meaning "unspecified".
    channels: Optional[int] = None
    height: Optional[int] = None
    width: Optional[int] = None
    stride: Optional[int] = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 23 |
+
from abc import ABCMeta, abstractmethod
|
| 24 |
+
from typing import Dict
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ = ["Backbone"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Backbone(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for network backbones.
    """

    def __init__(self):
        """
        The `__init__` method of any subclass can specify its own set of arguments.
        """
        super().__init__()

    @abstractmethod
    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Some backbones require the input height and width to be divisible by a
        specific integer. This is typically true for encoder / decoder type networks
        with lateral connection (e.g., FPN) for which feature maps need to match
        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
        input size divisibility is required.
        """
        return 0

    @property
    def padding_constraints(self) -> Dict[str, int]:
        """
        This property is a generalization of size_divisibility. Some backbones and training
        recipes require specific padding constraints, such as enforcing divisibility by a specific
        integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
        in :paper:vitdet). `padding_constraints` contains these optional items like:
        {
            "size_divisibility": int,
            "square_size": int,
            # Future options are possible
        }
        `size_divisibility` will read from here if presented and `square_size` indicates the
        square padding size if `square_size` > 0.

        TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints
        could be generalized as TypedDict (Python 3.8+) to support more types in the future.
        """
        return {}

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]

        NOTE(review): relies on subclasses defining `_out_features`,
        `_out_feature_channels` and `_out_feature_strides` attributes.
        """
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
|
models/backbones/eva02/batch_norm.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from fvcore.nn.distributed import differentiable_all_reduce
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from .wrappers import BatchNorm2d
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialized to (1 - eps) so that the default transform is exactly
        # the identity after the `rsqrt(running_var + eps)` in forward().
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # NOTE(review): running stats are shared by reference with the
            # source module (no clone) — matches upstream detectron2.
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # Recurse into children and replace any converted submodules.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_norm(norm, out_channels):
    """
    Build a normalization layer from a name or a callable.

    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): number of channels the layer normalizes.

    Returns:
        nn.Module or None: the normalization layer (None for None/"" input).
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            # Empty string means "no normalization".
            return None
        factories = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": lambda channels: LayerNorm(channels),
        }
        factory = factories[norm]
        return factory(out_channels)
    # Already a callable: invoke it directly.
    return norm(out_channels)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class CycleBatchNormList(nn.ModuleList):
    """
    Implement domain-specific BatchNorm by cycling.

    When a BatchNorm layer is used for multiple input domains or input
    features, it might need to maintain a separate test-time statistics
    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.

    This module implements it by using N separate BN layers
    and it cycles through them every time a forward() is called.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        # The per-domain BN layers are created without affine params;
        # a single shared affine transform is applied afterwards.
        self._affine = kwargs.pop("affine", True)
        super().__init__([bn_class(**kwargs, affine=False) for k in range(length)])
        if self._affine:
            # shared affine, domain-specific BN
            channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        # Index of the BN layer to use on the next forward() call.
        self._pos = 0

    def forward(self, x):
        ret = self[self._pos](x)
        # Advance to the next BN layer, wrapping around.
        self._pos = (self._pos + 1) % len(self)

        if self._affine:
            w = self.weight.reshape(1, -1, 1, 1)
            b = self.bias.reshape(1, -1, 1, 1)
            return ret * w + b
        else:
            return ret

    def extra_repr(self):
        return f"affine={self._affine}"
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        # Learnable per-channel scale and shift.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        # Normalize each spatial position across the channel dimension.
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Broadcast the (C,) affine params over H and W.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
|
models/backbones/eva02/blocks.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
import fvcore.nn.weight_init as weight_init
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .batch_norm import FrozenBatchNorm2d, get_norm
|
| 8 |
+
from .wrappers import Conv2d
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
CNN building blocks.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CNNBlockBase(nn.Module):
    """
    Base class for CNN blocks: modules with fixed input/output channel counts
    and a stride, whose `forward()` consumes and produces NCHW tensors. The
    computation performed is arbitrary but must honor the declared channels
    and stride.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable: disable gradients for every parameter
        and convert all BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 pointwise convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # kxk conv applied per-channel (groups == in_channels): spatial mixing only.
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,  # bias is redundant when a norm layer follows
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        # 1x1 conv: channel mixing only.
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # Caffe2-style MSRA initialization for both convolutions.
        for layer in (self.depthwise, self.pointwise):
            weight_init.c2_msra_fill(layer)

    def forward(self, x):
        # Spatial (depthwise) mixing followed by channel (pointwise) mixing.
        return self.pointwise(self.depthwise(x))
|
models/backbones/eva02/drop.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample, for use in the main path of
    residual blocks.

    Each sample in the batch is independently zeroed with probability
    ``drop_prob``; surviving samples are rescaled by 1/keep_prob (when
    ``scale_by_keep``) so the expected value is unchanged. A no-op at
    evaluation time or when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNet inputs).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DropPath(nn.Module):
    """Stochastic Depth: randomly zero whole samples of the residual branch
    (see ``drop_path``); active only in training mode."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegates to the functional form; self.training gates the dropping.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
models/backbones/eva02/fpn.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import fvcore.nn.weight_init as weight_init
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _assert_strides_are_log2_contiguous(strides):
|
| 8 |
+
"""
|
| 9 |
+
Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
|
| 10 |
+
"""
|
| 11 |
+
for i, stride in enumerate(strides[1:], 1):
|
| 12 |
+
assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
|
| 13 |
+
stride, strides[i - 1]
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LastLevelMaxPool(nn.Module):
    """
    Generates the extra downsampled P6 level of the original FPN from P5.
    """

    def __init__(self):
        super().__init__()
        self.num_levels = 1
        self.in_feature = "p5"

    def forward(self, x):
        # kernel_size=1 with stride=2 simply subsamples every other pixel.
        return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LastLevelP6P7(nn.Module):
    """
    Generates the extra P6 and P7 levels used by RetinaNet from the C5 feature.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        # Each extra level halves the resolution: 3x3 conv, stride 2, pad 1.
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for layer in (self.p6, self.p7):
            weight_init.c2_xavier_fill(layer)

    def forward(self, c5):
        p6 = self.p6(c5)
        # ReLU is applied only between the two extra levels.
        p7 = self.p7(F.relu(p6))
        return [p6, p7]
|
models/backbones/eva02/main.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from mmcv.runner.checkpoint import load_state_dict
|
| 5 |
+
from mmdet.models.builder import BACKBONES
|
| 6 |
+
from .vit import ViT, SimpleFeaturePyramid, partial
|
| 7 |
+
from .fpn import LastLevelMaxPool
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@BACKBONES.register_module()
class EVA02(nn.Module):
    """
    EVA-02 backbone for mmdet: a plain ViT wrapped in a SimpleFeaturePyramid;
    ``forward`` returns the pyramid's feature maps as a list.
    """

    def __init__(
        self,
        # args for ViT
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=False,
        frozen_blocks=-1,
        # args for simple FPN
        fpn_in_feature="last_feat",
        fpn_out_channels=256,
        fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
        fpn_top_block=False,
        fpn_norm="LN",
        fpn_square_pad=0,
        pretrained=None
    ):
        """
        Args:
            img_size .. frozen_blocks: forwarded verbatim to ``ViT``.
            fpn_*: forwarded to ``SimpleFeaturePyramid`` (``fpn_top_block``
                adds a max-pooled extra level when True).
            pretrained (str | None): path to a checkpoint whose ``'model'``
                entry is loaded non-strictly into this module.
        """
        super().__init__()

        self.backbone = SimpleFeaturePyramid(
            ViT(
                img_size=img_size,
                real_img_size=real_img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                depth=depth,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=drop_path_rate,
                norm_layer=norm_layer,
                use_abs_pos=use_abs_pos,
                pt_hw_seq_len=pt_hw_seq_len,
                intp_freq=intp_freq,
                window_size=window_size,
                window_block_indexes=window_block_indexes,
                residual_block_indexes=residual_block_indexes,
                use_act_checkpoint=use_act_checkpoint,
                pretrain_img_size=pretrain_img_size,
                pretrain_use_cls_token=pretrain_use_cls_token,
                out_feature=out_feature,
                xattn=xattn,
                frozen_blocks=frozen_blocks,
            ),
            in_feature=fpn_in_feature,
            out_channels=fpn_out_channels,
            scale_factors=fpn_scale_factors,
            top_block=LastLevelMaxPool() if fpn_top_block else None,
            norm=fpn_norm,
            square_pad=fpn_square_pad,
        )
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Load pretrained weights from ``pretrained`` (no-op when None)."""
        if pretrained is None:
            return
        # Lazy %-style args: the message is only formatted if the level is enabled.
        logging.info('Loading pretrained weights from %s', pretrained)
        # map_location='cpu' keeps loading device-agnostic: a checkpoint saved
        # on GPU would otherwise fail on CPU-only hosts (or pin tensors to the
        # saving device). Parameters move with the module's later .to()/.cuda().
        state_dict = torch.load(pretrained, map_location='cpu')['model']
        # strict=False: FPN layers absent from the ViT checkpoint keep their
        # initialized values.
        load_state_dict(self, state_dict, strict=False)

    def forward(self, x):
        # SimpleFeaturePyramid returns a dict of feature maps; callers expect a list.
        outs = self.backbone(x)
        return list(outs.values())
|
models/backbones/eva02/utils.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy import interpolate
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"window_partition",
|
| 11 |
+
"window_unpartition",
|
| 12 |
+
"add_decomposed_rel_pos",
|
| 13 |
+
"get_abs_pos",
|
| 14 |
+
"PatchEmbed",
|
| 15 |
+
"VisionRotaryEmbeddingFast",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def window_partition(x, window_size):
    """
    Split a [B, H, W, C] tensor into non-overlapping square windows,
    zero-padding the bottom/right edges when H or W is not a multiple of
    window_size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    pad_bottom = (window_size - height % window_size) % window_size
    pad_right = (window_size - width % window_size) % window_size
    if pad_bottom or pad_right:
        # F.pad pads dims last-to-first: (C_lo, C_hi, W_lo, W_hi, H_lo, H_hi).
        x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_h, padded_w = height + pad_bottom, width + pad_right

    x = x.view(
        batch,
        padded_h // window_size,
        window_size,
        padded_w // window_size,
        window_size,
        channels,
    )
    windows = (
        x.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size, window_size, channels)
    )
    return windows, (padded_h, padded_w)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reverse of window_partition: reassemble windows into a [B, H, W, C] map
    and strip any bottom/right padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw
    # Infer the batch size from the total window count.
    windows_per_image = (padded_h // window_size) * (padded_w // window_size)
    batch = windows.shape[0] // windows_per_image
    x = windows.view(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    if padded_h > height or padded_w > width:
        x = x[:, :height, :width, :].contiguous()
    return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    # A (q, k) pair can have 2*max(q,k)-1 distinct relative offsets.
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # NOTE: hardcoded True, so the linear-interpolation branch below is dead code.
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Interpolate rel pos.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # Geometric-progression spacing: source coordinates grow by a
            # factor q per step away from the center, so large offsets are
            # resampled on a log-like grid (cf. Swin-v2 style interpolation).
            # q = 1.13492
            q = 1.0903078
            dis = []

            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            # Symmetric coordinates: negative mirror + 0 + positive side.
            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            # Destination coordinates: unit-spaced grid centered at 0.
            dx = np.arange(-t, t + 0.1, 1.0)
            # print("x = %s" % str(x))
            # print("dx = %s" % str(dx))
            all_rel_pos_bias = []
            # Cubic-spline interpolation, one embedding channel at a time.
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Per-axis tables of embeddings indexed by relative offset: (q_h, k_h, C), (q_w, k_w, C).
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Dot each query with its per-offset embedding: biases of shape (B, q_h, q_w, k_h/k_w).
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Decomposed bias: height term broadcast over k_w, width term over k_h.
    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Resize absolute positional embeddings to the target token grid, removing
    the cls_token slot when present.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        # The cls-token embedding has no spatial location; drop it.
        abs_pos = abs_pos[:, 1:]
    num_tokens = abs_pos.shape[1]
    # The pretraining grid is assumed square.
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == h and side == w:
        return abs_pos.reshape(1, h, w, -1)
    # Bicubic-resize the (side x side) grid of embeddings to (h x w).
    resized = F.interpolate(
        abs_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding via a single strided convolution.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        # Project patches, then move channels last: B C H W -> B H W C.
        return self.proj(x).permute(0, 2, 3, 1)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
from math import pi
|
| 224 |
+
|
| 225 |
+
import torch
|
| 226 |
+
from torch import nn
|
| 227 |
+
|
| 228 |
+
from einops import rearrange, repeat
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def broadcat(tensors, dim = -1):
    """
    Concatenate tensors along `dim`, first broadcasting every other dimension
    to the per-position maximum across the inputs.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    if dim < 0:
        dim += rank

    # Per-dimension tuple of sizes across all tensors.
    sizes_by_dim = list(zip(*(t.shape for t in tensors)))
    target = []
    for axis, sizes in enumerate(sizes_by_dim):
        if axis == dim:
            # The concat axis keeps each tensor's own size.
            target.append(None)
            continue
        # Broadcastable means at most two distinct sizes (e.g. {1, n}).
        assert len(set(sizes)) <= 2, 'invalid dimensions for broadcastable concatentation'
        target.append(max(sizes))

    expanded = []
    for t in tensors:
        shape = [t.shape[axis] if axis == dim else target[axis] for axis in range(rank)]
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def rotate_half(x):
    """
    RoPE 90-degree rotation: map every consecutive channel pair (x1, x2) of
    the last dimension to (-x2, x1).
    """
    # '... (d r) -> ... d r' with r=2: split the last dim into pairs.
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    # '... d r -> ... (d r)': merge the pairs back into one dimension.
    return x.flatten(-2)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding (RoPE) for vision tokens.

    Builds per-axis frequency tables for an ``ft_seq_len x ft_seq_len`` token
    grid (coordinates rescaled to the pretraining grid of ``pt_seq_len``) and
    caches their cos/sin as buffers; ``forward`` rotates a channel slice of
    the input with them.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        """
        Args:
            dim (int): per-axis rotary dimension; the full rotated width is 2*dim.
            pt_seq_len (int): pretraining token-grid side length.
            ft_seq_len (int, optional): fine-tuning grid side; defaults to pt_seq_len.
            custom_freqs (Tensor, optional): user-supplied frequencies (overrides freqs_for).
            freqs_for (str): 'lang' (inverse-theta powers), 'pixel' (linear band), or 'constant'.
            theta (int): base for the 'lang' frequency schedule.
            max_freq (int): top frequency for the 'pixel' schedule.
            num_freqs (int): number of frequencies for the 'constant' schedule.
        """
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Classic RoPE schedule: theta^(-2i/dim) for i in [0, dim/2).
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # Positions of the fine-tuning grid expressed in pretraining-grid units.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Outer product position x frequency, each entry duplicated so that
        # consecutive channel pairs share one rotation angle.
        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        # Combine the two axes into one (H, W, 2*dim) table.
        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

        # Buffers (not parameters): move with the module but are not trained.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index = 0):
        """Rotate channels [start_index, start_index + rot_dim) of ``t``;
        channels outside that slice pass through unchanged."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        # Standard RoPE rotation: t*cos + rotate_half(t)*sin.
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim = -1)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class VisionRotaryEmbeddingFast(nn.Module):
    """Flattened-sequence variant of 2D RoPE.

    Precomputes cos/sin tables with one row per flattened token (shape
    (H*W, rot_dim)); ``forward`` rotates the whole last dimension of the
    input. When ``real_img_size`` is given, the square (ft_seq_len x
    ft_seq_len) tables are bicubic-resized to that (possibly non-square) grid.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        real_img_size = None
    ):
        """
        Args:
            dim (int): per-axis rotary dimension; rows of the cached tables
                have width 2*dim.
            pt_seq_len (int): pretraining token-grid side length.
            ft_seq_len (int, optional): fine-tuning grid side; defaults to pt_seq_len.
            custom_freqs / freqs_for / theta / max_freq / num_freqs: frequency
                schedule selection, as in VisionRotaryEmbedding.
            real_img_size (tuple, optional): target (H, W) grid to resize the
                frequency tables to. NOTE(review): presumably given in
                token-grid units (feature H, W), not pixels — confirm with the
                constructing ViT code.
        """
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Classic RoPE schedule: theta^(-2i/dim) for i in [0, dim/2).
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # Positions of the fine-tuning grid expressed in pretraining-grid units.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Position x frequency angles, duplicated per channel pair, then the
        # two axes combined into a (H, W, 2*dim) table.
        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        # Flatten the grid: one row of angles per token.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        if real_img_size is not None:
            # Resize the square table to the actual (H, W) grid channel-wise.
            new_freqs_cos = F.interpolate(
                freqs_cos.reshape(1, ft_seq_len, ft_seq_len, -1).permute(0, 3, 1, 2),
                size=real_img_size,
                mode="bicubic",
                align_corners=False,
            ).permute(0, 2, 3, 1)

            new_freqs_sin = F.interpolate(
                freqs_sin.reshape(1, ft_seq_len, ft_seq_len, -1).permute(0, 3, 1, 2),
                size=real_img_size,
                mode="bicubic",
                align_corners=False,
            ).permute(0, 2, 3, 1)

            # Buffers (not parameters): move with the module but are not trained.
            self.register_buffer("freqs_cos", new_freqs_cos.view(-1, freqs.shape[-1]))
            self.register_buffer("freqs_sin", new_freqs_sin.view(-1, freqs.shape[-1]))
        else:
            self.register_buffer("freqs_cos", freqs_cos)
            self.register_buffer("freqs_sin", freqs_sin)

    def forward(self, t):
        # Rotate every channel of t; tables broadcast over the middle dim
        # (per-token rows inserted at dim 0, extra axis at dim 1).
        return t * self.freqs_cos[:, None, :] + rotate_half(t) * self.freqs_sin[:, None, :]
|
models/backbones/eva02/vit.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import fvcore.nn.weight_init as weight_init
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.utils.checkpoint as cp
|
| 10 |
+
|
| 11 |
+
from .wrappers import Conv2d
|
| 12 |
+
from .batch_norm import get_norm
|
| 13 |
+
from .blocks import CNNBlockBase
|
| 14 |
+
from .fpn import _assert_strides_are_log2_contiguous
|
| 15 |
+
|
| 16 |
+
from .backbone import Backbone
|
| 17 |
+
from .utils import (
|
| 18 |
+
PatchEmbed,
|
| 19 |
+
add_decomposed_rel_pos,
|
| 20 |
+
get_abs_pos,
|
| 21 |
+
window_partition,
|
| 22 |
+
window_unpartition,
|
| 23 |
+
VisionRotaryEmbeddingFast,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
| 28 |
+
except ImportError:
|
| 29 |
+
flash_attn_func = None
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SwiGLU(nn.Module):
|
| 35 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
| 36 |
+
norm_layer=nn.LayerNorm, subln=False
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
out_features = out_features or in_features
|
| 40 |
+
hidden_features = hidden_features or in_features
|
| 41 |
+
|
| 42 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
| 43 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
| 44 |
+
|
| 45 |
+
self.act = act_layer()
|
| 46 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
| 47 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
| 48 |
+
|
| 49 |
+
self.drop = nn.Dropout(drop)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x1 = self.w1(x)
|
| 53 |
+
x2 = self.w2(x)
|
| 54 |
+
hidden = self.act(x1) * x2
|
| 55 |
+
x = self.ffn_ln(hidden)
|
| 56 |
+
x = self.w3(x)
|
| 57 |
+
x = self.drop(x)
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) feature map with rotary
    position embeddings applied to q/k.

    Uses the flash-attention kernel when requested AND available; otherwise
    falls back to PyTorch's ``scaled_dot_product_attention``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        rope=None,
        xattn=True,
    ):
        """
        Args:
            dim (int): input channel dimension.
            num_heads (int): number of attention heads.
            qkv_bias (bool): add learnable biases to q and v (k is deliberately
                bias-free, following the EVA-02 reference implementation).
            qk_scale (float or None): override for the attention scale.
                NOTE(review): both attention paths below use their built-in
                default scale (head_dim ** -0.5); a custom qk_scale is stored
                but not forwarded — kept as-is for checkpoint compatibility.
            attn_head_dim (int or None): override for the per-head dimension.
            rope (nn.Module or None): rotary embedding applied to q and k.
            xattn (bool): prefer the flash-attention kernel when installed.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)  # k stays bias-free
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1)
        k = k.reshape(B, N, self.num_heads, -1)
        v = v.reshape(B, N, self.num_heads, -1)

        ## rope: rotary position embedding on q/k (cast back to v's dtype for AMP)
        # Bug fix: guard against rope=None (the constructor default) instead of
        # crashing with "'NoneType' object is not callable".
        if self.rope is not None:
            q = self.rope(q).type_as(v)
            k = self.rope(k).type_as(v)

        # Bug fix: the module-level import sets flash_attn_func = None when
        # flash-attn is not installed; previously this path raised a TypeError.
        # Fall back to SDPA in that case.
        if self.xattn and flash_attn_func is not None:
            # flash_attn_func expects (B, N, num_heads, head_dim) layout.
            x = flash_attn_func(q, k, v).reshape(B, N, -1)
        else:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, head_dim
            k = k.permute(0, 2, 1, 3)  # B, num_heads, N, head_dim
            v = v.permute(0, 2, 1, 3)  # B, num_heads, N, head_dim
            x = F.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = x.view(B, H, W, C)

        return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        # NOTE: forward() applies self.children() in registration order, so the
        # declaration order below (conv1 -> norm1 -> act1 -> ... -> norm3) is
        # load-bearing; do not reorder these assignments.
        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        # MSRA (He) init for convs; standard scale/shift init for the first
        # two norms.
        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        # This makes the block an identity mapping at initialization.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        # Run every registered child module in declaration order.
        out = x
        for layer in self.children():
            out = layer(out)

        # Residual connection; no final activation (see class docstring).
        out = x + out
        return out
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
        use_act_checkpoint=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
                The 4*2/3 default presumably compensates for SwiGLU's two
                input projections — TODO confirm against the EVA-02 paper.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            rope (nn.Module or None): rotary position embedding forwarded to
                the attention layer.
            xattn (bool): prefer the flash-attention kernel in Attention.
            use_act_checkpoint (bool): If True, recompute activations during
                backward (torch.utils.checkpoint) to save memory.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
        )

        # NOTE(review): imported inside __init__ rather than at module top —
        # presumably to avoid an import cycle; confirm before moving it.
        from .drop import DropPath

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

        self.use_act_checkpoint = use_act_checkpoint

    def inner_forward(self, x):
        # x: (B, H, W, C) channels-last feature map.
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        # Pre-norm residual connections for attention and MLP sub-blocks.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            # The conv residual block is channels-first, hence the permutes.
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x

    def forward(self, x):
        # Checkpoint only when it can pay off: training with a grad-requiring
        # input (checkpointing a no-grad pass would just waste compute).
        if self.training and x.requires_grad and self.use_act_checkpoint:
            return cp.checkpoint(self.inner_forward, x)
        else:
            return self.inner_forward(x)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class ViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=True,
        frozen_blocks=-1,
    ):
        """
        Args:
            img_size (int): Input image size (in pixels) used to size the
                global rope table.
            real_img_size (tuple[int, int]): actual (H, W) input resolution in
                pixels; converted to patch units and passed to the global rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            pt_hw_seq_len (int): per-axis sequence length the rope was
                pretrained with.
            intp_freq (bool): If True, interpolate rope frequencies to the
                fine-tuning sequence length.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
            xattn (bool): prefer flash-attention inside the attention blocks.
            frozen_blocks (int): if >= 0, freeze patch/pos embeddings and the
                first `frozen_blocks` transformer blocks; -1 disables freezing.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token
        self.frozen_blocks = frozen_blocks

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # Rotary embeddings act on half of each head's dimension.
        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size
        real_img_size = (real_img_size[0] // patch_size, real_img_size[1] // patch_size)

        # Separate rope tables: one sized for windowed blocks, one for global
        # attention blocks.
        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
            real_img_size=real_img_size
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn,
                use_act_checkpoint=use_act_checkpoint
            )
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, standard LayerNorm init."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Simplified: the redundant re-check of isinstance(m, nn.Linear)
            # inside this branch was removed (behavior unchanged).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Run patch embedding, optional abs-pos, and all blocks.

        Returns a dict {out_feature: (B, C, H', W') tensor}.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Interpolate the pretrained positional grid to the current size.
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        # Blocks run channels-last; convert to channels-first for consumers.
        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs

    def _freeze_stages(self):
        """Freeze the patch embedding, the (optional) positional embedding and
        the first `frozen_blocks` transformer blocks."""
        def freeze_module(m):
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.frozen_blocks >= 0:
            freeze_module(self.patch_embed)
            # Bug fix: pos_embed is None when use_abs_pos=False; the previous
            # unconditional attribute access raised AttributeError here.
            if self.pos_embed is not None:
                self.pos_embed.requires_grad = False

            for i in range(0, self.frozen_blocks):
                freeze_module(self.blocks[i])

    def train(self, mode=True):
        # Re-apply freezing after every train()/eval() switch so frozen
        # modules stay in eval mode.
        super().train(mode)
        self._freeze_stages()
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class SimpleFeaturePyramid(Backbone):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming
                from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super(SimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors

        input_shapes = net.output_shape()
        # Output stride of each pyramid level, e.g. backbone stride 16 with
        # scale 4.0 -> pyramid stride 4.
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        # Conv bias is only useful when no norm layer follows it.
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            # Up/down-sample the single backbone feature map to this level's
            # resolution: transposed convs to upsample, max-pool to downsample.
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # 1x1 then 3x3 conv project every level to `out_channels`.
            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            # Submodule names ("simfp_<stage>") appear in state_dict keys;
            # keep them stable for checkpoint compatibility.
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            # NOTE: `stage` is deliberately the leftover value from the last
            # loop iteration, i.e. the deepest pyramid level built above.
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    @property
    def padding_constraints(self):
        # NOTE(review): "size_divisiblity" is misspelled, but the key matches
        # upstream detectron2 consumers — keep as-is for compatibility.
        return {
            "size_divisiblity": self._size_divisibility,
            "square_size": self._square_pad,
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to pyramid feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.net(x)
        features = bottom_up_features[self.in_feature]
        results = []

        # Every stage consumes the same single-scale backbone feature map.
        for stage in self.stages:
            results.append(stage(features))

        if self.top_block is not None:
            # Feed the top block from the backbone output if it exposes the
            # requested feature, otherwise from the matching pyramid level.
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    # Default depth: deepest position -> exponent 0 -> no decay.
    depth = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            # Embeddings sit at the very bottom of the network.
            depth = 0
        elif ".blocks." in name and ".residual." not in name:
            # ".blocks.<idx>." -> depth is block index + 1.
            block_idx = int(name[name.find(".blocks."):].split(".")[2])
            depth = block_idx + 1

    return lr_decay_rate ** (num_layers + 1 - depth)
|
models/backbones/eva02/wrappers.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
"""
|
| 3 |
+
Wrappers around on some nn functions, mainly to support empty tensors.
|
| 4 |
+
|
| 5 |
+
Ideally, add support directly in PyTorch to empty tensors in those functions.
|
| 6 |
+
|
| 7 |
+
These can be removed once https://github.com/pytorch/pytorch/issues/12013
|
| 8 |
+
is implemented
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import warnings
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
import torch
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensor, so the output can trace to the inputs.
    In scripting or eager, `x` should be a list of int.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"
        # torch.stack keeps the graph connected to the inputs; as_tensor would
        # record the shapes as constants.
        stacked = torch.stack(x)
        # Only move when necessary, to avoid recording a hard-coded device.
        return stacked if stacked.device == device else stacked.to(device=device)
    return torch.as_tensor(x, device=device)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
    assert isinstance(tensors, (list, tuple))
    # Concatenating a single tensor is the tensor itself — skip the copy.
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def empty_input_loss_func_wrapper(loss_func):
    """Wrap `loss_func` so that an empty batch with reduction="mean" yields a
    zero (but gradient-connected) loss instead of NaN."""

    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        if reduction == "mean" and target.numel() == 0:
            # Mean over zero elements would be NaN; return a zero that still
            # participates in autograd.
            return input.sum() * 0.0
        return loss_func(input, target, reduction=reduction, **kwargs)

    return wrapped_loss_func


# Drop-in replacement for F.cross_entropy that is safe on empty batches.
cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class _NewEmptyTensorOp(torch.autograd.Function):
    # Autograd-aware "resize to a new (possibly empty) shape": forward returns
    # an uninitialized tensor of `new_shape`; backward returns an empty tensor
    # of the original shape, so gradients keep flowing through empty batches.
    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape  # remember the input shape for backward
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        # Second return value is None: no gradient w.r.t. `new_shape`.
        return _NewEmptyTensorOp.apply(grad, shape), None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        # Pop the extra kwargs before delegating to torch.nn.Conv2d, which
        # would reject unknown keyword arguments.
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # Empty-batch guard: SyncBatchNorm cannot handle 0-element inputs.
            with warnings.catch_warnings(record=True):
                if x.numel() == 0 and self.training:
                    # https://github.com/pytorch/pytorch/issues/12013
                    assert not isinstance(
                        self.norm, torch.nn.SyncBatchNorm
                    ), "SyncBatchNorm does not support empty inputs!"

        # conv -> (optional) norm -> (optional) activation.
        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Re-exported aliases, kept so call sites can import every layer type from
# this module; these PyTorch ops already handle empty inputs natively.
ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def nonzero_tuple(x):
    """
    A 'as_tuple=True' version of torch.nonzero to support torchscript.
    because of https://github.com/pytorch/pytorch/issues/38718
    """
    if not torch.jit.is_scripting():
        return x.nonzero(as_tuple=True)
    # TorchScript path: emulate as_tuple=True by unbinding the index columns.
    if x.dim() == 0:
        x = x.unsqueeze(0)
    return x.nonzero().unbind(1)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing friendly way to cast tensor to another tensor's device. Device will be treated
    as constant during tracing, scripting the casting process as whole can workaround this issue.
    """
    target_device = dst.device
    return src.to(device=target_device)
|
models/sparsebev.py
CHANGED
|
@@ -14,7 +14,7 @@ from .utils import GridMask, pad_multiple, GpuPhotoMetricDistortion
|
|
| 14 |
class SparseBEV(MVXTwoStageDetector):
|
| 15 |
def __init__(self,
|
| 16 |
data_aug=None,
|
| 17 |
-
stop_prev_grad=
|
| 18 |
pts_voxel_layer=None,
|
| 19 |
pts_voxel_encoder=None,
|
| 20 |
pts_middle_encoder=None,
|
|
@@ -99,12 +99,12 @@ class SparseBEV(MVXTwoStageDetector):
|
|
| 99 |
for img_meta in img_metas:
|
| 100 |
img_meta.update(input_shape=input_shape)
|
| 101 |
|
| 102 |
-
if self.training and self.stop_prev_grad:
|
| 103 |
H, W = input_shape
|
| 104 |
img = img.reshape(B, -1, 6, C, H, W)
|
| 105 |
|
| 106 |
-
img_grad = img[:, :
|
| 107 |
-
img_nograd = img[:,
|
| 108 |
|
| 109 |
all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]
|
| 110 |
|
|
|
|
| 14 |
class SparseBEV(MVXTwoStageDetector):
|
| 15 |
def __init__(self,
|
| 16 |
data_aug=None,
|
| 17 |
+
stop_prev_grad=0,
|
| 18 |
pts_voxel_layer=None,
|
| 19 |
pts_voxel_encoder=None,
|
| 20 |
pts_middle_encoder=None,
|
|
|
|
| 99 |
for img_meta in img_metas:
|
| 100 |
img_meta.update(input_shape=input_shape)
|
| 101 |
|
| 102 |
+
if self.training and self.stop_prev_grad > 0:
|
| 103 |
H, W = input_shape
|
| 104 |
img = img.reshape(B, -1, 6, C, H, W)
|
| 105 |
|
| 106 |
+
img_grad = img[:, :self.stop_prev_grad]
|
| 107 |
+
img_nograd = img[:, self.stop_prev_grad:]
|
| 108 |
|
| 109 |
all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]
|
| 110 |
|
val.py
CHANGED
|
@@ -112,6 +112,7 @@ def main():
|
|
| 112 |
logging.info('Creating model: %s' % cfgs.model.type)
|
| 113 |
model = build_model(cfgs.model)
|
| 114 |
model.cuda()
|
|
|
|
| 115 |
|
| 116 |
if world_size > 1:
|
| 117 |
model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
|
|
|
|
| 112 |
logging.info('Creating model: %s' % cfgs.model.type)
|
| 113 |
model = build_model(cfgs.model)
|
| 114 |
model.cuda()
|
| 115 |
+
model.fp16_enabled = True
|
| 116 |
|
| 117 |
if world_size > 1:
|
| 118 |
model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
|