Haisong Liu committed on
Commit
102ac67
·
unverified ·
1 Parent(s): e610aaf

Release model: vit_eva02_1600x640_trainval_future (#46)

Browse files
README.md CHANGED
@@ -28,6 +28,7 @@ This is the official PyTorch implementation for our ICCV 2023 paper:
28
  | [r50_nuimg_704x256_400q_36ep](configs/r50_nuimg_704x256_400q_36ep.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 28h (8x2080Ti) | 55.8 | - | 23.5 | [gdrive](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view) |
29
  | [r101_nuimg_1408x512](configs/r101_nuimg_1408x512.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r101_fpn_1x_nuim/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth) | 2d8h (8xV100) | 59.2 | - | 6.5 | [gdrive](https://drive.google.com/file/d/1dKu5cR1fuo-O0ynyBh-RCPtHrgut29mN/view) |
30
  | [vov99_dd3d_1600x640_trainval_future](configs/vov99_dd3d_1600x640_trainval_future.py) | [DD3D](https://drive.google.com/file/d/1gQkhWERCzAosBwG5bh2BKkt1k0TJZt-A/view) | 4d1h (8xA100) | 84.9 | 67.5 | - | [gdrive](https://drive.google.com/file/d/1TL0QoCiWD5uq8PCAWWE3A-g73ibK1R0S/view) |
 
31
 
32
  * We use `r50_nuimg_704x256` for ablation studies and `r50_nuimg_704x256_400q_36ep` for comparison with others.
33
  * We recommend using `r50_nuimg_704x256` to validate new ideas since it trains faster and the result is more stable.
 
28
  | [r50_nuimg_704x256_400q_36ep](configs/r50_nuimg_704x256_400q_36ep.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 28h (8x2080Ti) | 55.8 | - | 23.5 | [gdrive](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view) |
29
  | [r101_nuimg_1408x512](configs/r101_nuimg_1408x512.py) | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r101_fpn_1x_nuim/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth) | 2d8h (8xV100) | 59.2 | - | 6.5 | [gdrive](https://drive.google.com/file/d/1dKu5cR1fuo-O0ynyBh-RCPtHrgut29mN/view) |
30
  | [vov99_dd3d_1600x640_trainval_future](configs/vov99_dd3d_1600x640_trainval_future.py) | [DD3D](https://drive.google.com/file/d/1gQkhWERCzAosBwG5bh2BKkt1k0TJZt-A/view) | 4d1h (8xA100) | 84.9 | 67.5 | - | [gdrive](https://drive.google.com/file/d/1TL0QoCiWD5uq8PCAWWE3A-g73ibK1R0S/view) |
31
+ | [vit_eva02_1600x640_trainval_future](configs/vit_eva02_1600x640_trainval_future.py) | [EVA02](https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_seg_sys_o365.pth) | 11d (8xA100) | 85.3 | 70.2 | - | [gdrive](https://drive.google.com/file/d/1cx7h6PUqiaVWPixpcuB9AhsX3Sx4n0q_/view) |
32
 
33
  * We use `r50_nuimg_704x256` for ablation studies and `r50_nuimg_704x256_400q_36ep` for comparison with others.
34
  * We recommend using `r50_nuimg_704x256` to validate new ideas since it trains faster and the result is more stable.
configs/r50_nuimg_704x256.py CHANGED
@@ -54,7 +54,7 @@ model = dict(
54
  img_color_aug=True, # Move some augmentations to GPU
55
  img_norm_cfg=img_norm_cfg,
56
  img_pad_cfg=dict(size_divisor=32)),
57
- stop_prev_grad=False,
58
  img_backbone=img_backbone,
59
  img_neck=img_neck,
60
  pts_bbox_head=dict(
 
54
  img_color_aug=True, # Move some augmentations to GPU
55
  img_norm_cfg=img_norm_cfg,
56
  img_pad_cfg=dict(size_divisor=32)),
57
+ stop_prev_grad=0,
58
  img_backbone=img_backbone,
59
  img_neck=img_neck,
60
  pts_bbox_head=dict(
configs/vit_eva02_1600x640_trainval_future.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
_base_ = ['./r50_nuimg_704x256.py']

# Standard 10-class nuScenes detection setup.
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# If the point cloud range changes, the model's range must change with it.
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

# EVA-02 ViT-L backbone with a ViTDet-style simple feature pyramid.
img_backbone = dict(
    _delete_=True,  # discard the backbone inherited from the base config
    type='EVA02',
    img_size=1536,
    real_img_size=(640, 1600),
    patch_size=16,
    in_chans=3,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    mlp_ratio=4 * 2 / 3,
    qkv_bias=True,
    drop_path_rate=0.3,
    use_abs_pos=True,
    window_size=16,
    # Windowed attention in every block except every third one
    # (indexes 2, 5, 8, ...), which attends globally.
    window_block_indexes=[i for i in range(24) if i % 3 != 2],
    residual_block_indexes=(),
    use_act_checkpoint=True,
    # args for simple FPN
    fpn_out_channels=256,
    fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
    fpn_top_block=True,
    fpn_norm="LN",
    fpn_square_pad=1600,
    pretrained='pretrain/eva02_L_coco_seg_sys_o365.pth',
    frozen_blocks=3,
)
img_norm_cfg = dict(
    mean=[123.675, 116.280, 103.530],
    std=[58.395, 57.120, 57.375],
    to_rgb=True
)

model = dict(
    img_backbone=img_backbone,
    img_neck=None,  # the backbone already emits a multi-scale pyramid
    stop_prev_grad=4,
    pts_bbox_head=dict(
        num_query=1600,
        transformer=dict(
            num_levels=5,
            num_points=8,
            num_frames=15))
)

ida_aug_conf = {
    'resize_lim': (0.94, 1.25),
    'final_dim': (640, 1600),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweepsFutureInterleave', prev_sweeps_num=7, next_sweeps_num=7),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
        'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
]

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweepsFutureInterleave', prev_sweeps_num=7, next_sweeps_num=7, test_mode=True),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1600, 900),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
            dict(type='Collect3D', keys=['img'], meta_keys=(
                'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
                'lidar2img', 'img_timestamp'))
        ])
]

data = dict(
    # "trainval" model: trained on both the train and val annotation files.
    train=dict(
        ann_file=['data/nuscenes/nuscenes_infos_train_sweep.pkl',
                  'data/nuscenes/nuscenes_infos_val_sweep.pkl'],
        pipeline=train_pipeline),
    val=dict(
        ann_file='data/nuscenes/nuscenes_infos_val_sweep.pkl',  # use nuscenes_infos_test_sweep.pkl for submission
        pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline)
)

load_from = None
revise_keys = None
loaders/pipelines/loading.py CHANGED
@@ -255,3 +255,138 @@ class LoadMultiViewImageFromMultiSweepsFuture(object):
255
  ))
256
 
257
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  ))
256
 
257
  return results
258
+
259
+
260
@PIPELINES.register_module()
class LoadMultiViewImageFromMultiSweepsFutureInterleave(object):
    """Load previous and future camera sweeps in interleaved temporal order.

    The 6-camera groups are appended to ``results`` as:
    curr, prev1, next1, prev2, next2, prev3, next3, ...

    Args:
        prev_sweeps_num (int): number of past sweeps to load.
        next_sweeps_num (int): number of future sweeps to load
            (must equal ``prev_sweeps_num`` for interleaving).
        color_type (str): color flag passed to ``mmcv.imread``.
        test_mode (bool): if True, use the fixed test-time sampling interval.
    """

    def __init__(self,
                 prev_sweeps_num=5,
                 next_sweeps_num=5,
                 color_type='color',
                 test_mode=False):
        self.prev_sweeps_num = prev_sweeps_num
        self.next_sweeps_num = next_sweeps_num
        self.color_type = color_type
        self.test_mode = test_mode

        # Interleaving pairs one past group with one future group per step.
        assert prev_sweeps_num == next_sweeps_num

        # Sweep sampling stride: randomized per sample during training,
        # fixed during testing for determinism.
        self.train_interval = [4, 8]
        self.test_interval = 6

        try:
            mmcv.use_backend('turbojpeg')
        except ImportError:
            mmcv.use_backend('cv2')

    def _load_sweeps(self, results, sweeps, sweeps_num, interval, cam_types):
        """Collect images/timestamps/filenames/lidar2img for ``sweeps_num``
        sweeps taken from one temporal direction.

        If no sweeps were recorded, the current frame is repeated as padding.
        """
        collected = dict(img=[], img_timestamp=[], filename=[], lidar2img=[])

        if len(sweeps) == 0:
            # Pad with copies of the current frame.
            for _ in range(sweeps_num):
                for j in range(len(cam_types)):
                    collected['img'].append(results['img'][j])
                    collected['img_timestamp'].append(results['img_timestamp'][j])
                    collected['filename'].append(results['filename'][j])
                    collected['lidar2img'].append(np.copy(results['lidar2img'][j]))
            return collected

        # Sample sweeps at the given stride; clamp to the last available one.
        choices = [(k + 1) * interval - 1 for k in range(sweeps_num)]
        for idx in sorted(choices):
            sweep_idx = min(idx, len(sweeps) - 1)
            sweep = sweeps[sweep_idx]

            # Fall back to the neighboring sweep if this one misses cameras.
            if len(sweep.keys()) < len(cam_types):
                sweep = sweeps[sweep_idx - 1]

            for sensor in cam_types:
                collected['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
                collected['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                collected['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                collected['lidar2img'].append(compose_lidar2img(
                    results['ego2global_translation'],
                    results['ego2global_rotation'],
                    results['lidar2ego_translation'],
                    results['lidar2ego_rotation'],
                    sweep[sensor]['sensor2global_translation'],
                    sweep[sensor]['sensor2global_rotation'],
                    sweep[sensor]['cam_intrinsic'],
                ))
        return collected

    def __call__(self, results):
        if self.prev_sweeps_num == 0 and self.next_sweeps_num == 0:
            return results

        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if self.test_mode:
            interval = self.test_interval
        else:
            interval = np.random.randint(self.train_interval[0], self.train_interval[1] + 1)

        results_prev = self._load_sweeps(
            results, results['sweeps']['prev'], self.prev_sweeps_num, interval, cam_types)
        results_next = self._load_sweeps(
            results, results['sweeps']['next'], self.next_sweeps_num, interval, cam_types)

        assert len(results_prev['img']) % 6 == 0
        assert len(results_next['img']) % 6 == 0

        # Interleave: one 6-camera group from the past, then one from the future.
        for i in range(len(results_prev['img']) // 6):
            for j in range(6):
                results['img'].append(results_prev['img'][i * 6 + j])
                results['img_timestamp'].append(results_prev['img_timestamp'][i * 6 + j])
                results['filename'].append(results_prev['filename'][i * 6 + j])
                results['lidar2img'].append(results_prev['lidar2img'][i * 6 + j])

            for j in range(6):
                results['img'].append(results_next['img'][i * 6 + j])
                results['img_timestamp'].append(results_next['img_timestamp'][i * 6 + j])
                results['filename'].append(results_next['filename'][i * 6 + j])
                results['lidar2img'].append(results_next['lidar2img'][i * 6 + j])

        return results
models/backbones/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .vovnet import VoVNet
 
2
 
3
- __all__ = ['VoVNet']
 
1
  from .vovnet import VoVNet
2
+ from .eva02 import EVA02
3
 
4
+ __all__ = ['VoVNet', 'EVA02']
models/backbones/eva02/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .main import EVA02
models/backbones/eva02/backbone.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/shape_spec.py
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+
8
@dataclass
class ShapeSpec:
    """
    A simple structure that contains basic shape specification about a tensor.
    It is often used as the auxiliary inputs/outputs of models,
    to complement the lack of shape inference ability among pytorch modules.
    """

    # All fields are optional; a producer fills in only what it knows.
    channels: Optional[int] = None  # number of channels (C)
    height: Optional[int] = None    # spatial height (H)
    width: Optional[int] = None     # spatial width (W)
    stride: Optional[int] = None    # feature stride relative to the network input (see Backbone.output_shape)
20
+
21
+
22
+ # Copyright (c) Facebook, Inc. and its affiliates.
23
+ from abc import ABCMeta, abstractmethod
24
+ from typing import Dict
25
+ import torch.nn as nn
26
+
27
+
28
+ __all__ = ["Backbone"]
29
+
30
+
31
class Backbone(nn.Module, metaclass=ABCMeta):
    """Abstract base class for network backbones."""

    def __init__(self):
        """A subclass's `__init__` may declare whatever arguments it needs."""
        super().__init__()

    @abstractmethod
    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Integer by which input height/width must be divisible, e.g. for
        encoder/decoder networks with lateral connections (FPN) whose
        "bottom up" and "top down" feature maps must match. 0 means no
        divisibility requirement.
        """
        return 0

    @property
    def padding_constraints(self) -> Dict[str, int]:
        """
        Generalization of `size_divisibility`: optional padding requirements
        such as {"size_divisibility": int, "square_size": int}. If present,
        `size_divisibility` is read from here, and `square_size` > 0 requests
        square padding of that size (e.g. ViTDet with large-scale jitter).

        TODO: use type of Dict[str, int] to avoid torchscipt issues. The type
        could be generalized as TypedDict (Python 3.8+) to support more types.
        """
        return {}

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]: backward-compatible default built from the
            `_out_feature_channels` / `_out_feature_strides` attributes.
        """
        return {
            feat: ShapeSpec(
                channels=self._out_feature_channels[feat],
                stride=self._out_feature_strides[feat],
            )
            for feat in self._out_features
        }
models/backbones/eva02/batch_norm.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import torch
3
+ import torch.distributed as dist
4
+ from fvcore.nn.distributed import differentiable_all_reduce
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .wrappers import BatchNorm2d
9
+
10
+
11
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # Bumped when the buffer layout changes; checked in _load_from_state_dict.
    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Buffers (not parameters): never updated by the optimizer.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Initialized to 1 - eps so the initial transform is exactly identity.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backfill running stats missing from old (pre-version-2) checkpoints.
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # Recurse into children, replacing any converted submodule in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
117
+
118
+
119
def get_norm(norm, out_channels):
    """
    Build a normalization layer by name.

    Args:
        norm (str or callable): either one of "BN", "SyncBN", "FrozenBN",
            "GN", "nnSyncBN", "LN", or a callable that takes a channel count
            and returns the normalization layer as an nn.Module. None or ""
            disables normalization.
        out_channels (int): number of channels the layer normalizes.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            return None
        factories = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": lambda channels: LayerNorm(channels),
        }
        norm = factories[norm]
    return norm(out_channels)
145
+
146
+
147
class CycleBatchNormList(nn.ModuleList):
    """
    Domain-specific BatchNorm implemented by cycling through N BN layers.

    When one BatchNorm position serves multiple input domains/features, it may
    need separate test-time statistics per domain (Sec 5.2 in
    :paper:`rethinking-batchnorm`). This module keeps N non-affine BN layers
    and advances to the next one on every forward() call, applying a single
    shared affine transform afterwards.

    NOTE: The caller of this module MUST guarantee to always call
    this module by multiple of N times. Otherwise its test-time statistics
    will be incorrect.
    """

    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
        """
        Args:
            length: number of BatchNorm layers to cycle.
            bn_class: the BatchNorm class to use
            kwargs: arguments of the BatchNorm class, such as num_features.
        """
        # Affine is handled here (shared), so the member BNs are non-affine.
        self._affine = kwargs.pop("affine", True)
        super().__init__([bn_class(**kwargs, affine=False) for _ in range(length)])
        if self._affine:
            # shared affine, domain-specific BN
            num_channels = self[0].num_features
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        self._pos = 0

    def forward(self, x):
        out = self[self._pos](x)
        self._pos = (self._pos + 1) % len(self)

        if not self._affine:
            return out
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        return out * w + b

    def extra_repr(self):
        return f"affine={self._affine}"
192
+
193
+
194
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        # Normalize each spatial position across the channel dim (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
models/backbones/eva02/blocks.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import fvcore.nn.weight_init as weight_init
5
+ from torch import nn
6
+
7
+ from .batch_norm import FrozenBatchNorm2d, get_norm
8
+ from .wrappers import Conv2d
9
+
10
+
11
+ """
12
+ CNN building blocks.
13
+ """
14
+
15
+
16
class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable: set every parameter to
        `requires_grad=False` and convert all BatchNorm layers to
        FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
56
+
57
+
58
class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # kxk conv applied per channel (groups == in_channels).
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,  # bias is redundant when followed by a norm
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        # 1x1 conv mixing channels.
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # default initialization
        for conv in (self.depthwise, self.pointwise):
            weight_init.c2_msra_fill(conv)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
models/backbones/eva02/drop.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
2
+ import torch.nn as nn
3
+
4
+
5
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    # Identity at eval time or when dropping is disabled.
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        # Rescale so the expected activation magnitude is unchanged.
        mask.div_(keep_prob)
    return x * mask
23
+
24
+
25
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
models/backbones/eva02/fpn.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import fvcore.nn.weight_init as weight_init
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+
6
+
7
+ def _assert_strides_are_log2_contiguous(strides):
8
+ """
9
+ Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
10
+ """
11
+ for i, stride in enumerate(strides[1:], 1):
12
+ assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
13
+ stride, strides[i - 1]
14
+ )
15
+
16
+
17
class LastLevelMaxPool(nn.Module):
    """
    Produce the extra P6 feature of the original FPN by downsampling P5
    with a stride-2 max pool.
    """

    def __init__(self):
        super().__init__()
        self.num_levels = 1
        self.in_feature = "p5"

    def forward(self, x):
        pooled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [pooled]
30
+
31
+
32
class LastLevelP6P7(nn.Module):
    """
    Produce the extra P6 and P7 feature maps from the C5 feature, as used
    by RetinaNet.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for conv in (self.p6, self.p7):
            weight_init.c2_xavier_fill(conv)

    def forward(self, c5):
        p6 = self.p6(c5)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]
models/backbones/eva02/main.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import torch.nn as nn
4
+ from mmcv.runner.checkpoint import load_state_dict
5
+ from mmdet.models.builder import BACKBONES
6
+ from .vit import ViT, SimpleFeaturePyramid, partial
7
+ from .fpn import LastLevelMaxPool
8
+
9
+
10
@BACKBONES.register_module()
class EVA02(nn.Module):
    """EVA-02 ViT backbone wrapped with a ViTDet-style SimpleFeaturePyramid.

    All constructor arguments of `ViT` and `SimpleFeaturePyramid` are exposed
    here so the backbone is fully configurable from an mmdet config dict.
    `forward` returns the pyramid feature maps as a plain list.
    """

    def __init__(
        self,
        # args for ViT
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=False,
        frozen_blocks=-1,
        # args for simple FPN
        fpn_in_feature="last_feat",
        fpn_out_channels=256,
        fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
        fpn_top_block=False,
        fpn_norm="LN",
        fpn_square_pad=0,
        pretrained=None
    ):
        super().__init__()

        # The ViT trunk produces a single feature map; SimpleFeaturePyramid
        # rescales it into a multi-scale pyramid.
        self.backbone = SimpleFeaturePyramid(
            ViT(
                img_size=img_size,
                real_img_size=real_img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                depth=depth,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=drop_path_rate,
                norm_layer=norm_layer,
                use_abs_pos=use_abs_pos,
                pt_hw_seq_len=pt_hw_seq_len,
                intp_freq=intp_freq,
                window_size=window_size,
                window_block_indexes=window_block_indexes,
                residual_block_indexes=residual_block_indexes,
                use_act_checkpoint=use_act_checkpoint,
                pretrain_img_size=pretrain_img_size,
                pretrain_use_cls_token=pretrain_use_cls_token,
                out_feature=out_feature,
                xattn=xattn,
                frozen_blocks=frozen_blocks,
            ),
            in_feature=fpn_in_feature,
            out_channels=fpn_out_channels,
            scale_factors=fpn_scale_factors,
            top_block=LastLevelMaxPool() if fpn_top_block else None,
            norm=fpn_norm,
            square_pad=fpn_square_pad,
        )
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Optionally load a checkpoint whose weights live under the 'model' key.

        Loading is non-strict, so keys that only exist on one side (e.g. the
        FPN neck) are tolerated.

        Args:
            pretrained (str or None): checkpoint path; no-op when None.
        """
        if pretrained is None:
            return
        # Lazy %-style args instead of eager interpolation (logging best practice).
        logging.info('Loading pretrained weights from %s', pretrained)
        # Fix: map_location='cpu' makes GPU-saved checkpoints loadable on
        # CPU-only hosts; parameters are moved to the target device later.
        state_dict = torch.load(pretrained, map_location='cpu')['model']
        load_state_dict(self, state_dict, strict=False)

    def forward(self, x):
        """Run the backbone and return the pyramid feature maps as a list."""
        outs = self.backbone(x)
        return list(outs.values())
models/backbones/eva02/utils.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import math
3
+ import numpy as np
4
+ from scipy import interpolate
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ __all__ = [
10
+ "window_partition",
11
+ "window_unpartition",
12
+ "add_decomposed_rel_pos",
13
+ "get_abs_pos",
14
+ "PatchEmbed",
15
+ "VisionRotaryEmbeddingFast",
16
+ ]
17
+
18
+
19
def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Zero-pad bottom/right so both spatial dims divide evenly by window_size.
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp = H + pad_h
    Wp = W + pad_w

    num_h = Hp // window_size
    num_w = Wp // window_size
    # Split each spatial dim into (num_windows, window_size), then gather the
    # two window axes next to each other before flattening into the batch dim.
    x = x.view(B, num_h, window_size, num_w, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size, window_size, C)
    )
    return windows, (Hp, Wp)
41
+
42
+
43
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    num_h = Hp // window_size
    num_w = Wp // window_size
    # Recover the batch size from how many windows each padded image yields.
    B = windows.shape[0] // (num_h * num_w)

    x = windows.view(B, num_h, num_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Strip the zero padding that window_partition added on the bottom/right.
    if (Hp, Wp) != (H, W):
        x = x[:, :H, :W, :].contiguous()
    return x
64
+
65
+
66
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    # A (q, k) pair spans 2*max(q, k) - 1 distinct relative offsets.
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # NOTE(review): hard-coded True, so the linear-interpolation branch below
    # is dead code kept only for reference.
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Interpolate rel pos.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # Geometric ratio for the log-spaced source sample grid
            # (BEiT-style relative-position interpolation).
            # q = 1.13492
            q = 1.0903078
            dis = []

            # Source offsets grow geometrically away from 0.
            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            # Mirror to negative offsets so the grid is symmetric around 0.
            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            # Target offsets are uniformly spaced over the destination range.
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)
            # print("x = %s" % str(x))
            # print("dx = %s" % str(dx))
            all_rel_pos_bias = []
            # Cubic interpolation per channel; runs on CPU via numpy and is
            # moved back to rel_pos's device afterwards.
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    # Gather one embedding row per (q, k) relative offset -> (q_size, k_size, C).
    return rel_pos_resized[relative_coords.long()]
126
+
127
+
128
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Per-axis relative embeddings: (q_h, k_h, C) and (q_w, k_w, C).
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Contract the channel dim against each axis's embeddings: the relative
    # bias decomposes into a height term plus a width term.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Broadcast-add the two axis terms over the full (k_h, k_w) grid, then
    # flatten back to the (q_h*q_w, k_h*k_w) attention layout.
    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
158
+
159
+
160
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        # Drop the leading class-token embedding; only patch tokens remain.
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    # Pretrained embeddings must form a square token grid.
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size == h and size == w:
        # Already the right resolution: just reshape to the spatial layout.
        return abs_pos.reshape(1, h, w, -1)

    # Resize the square grid to (h, w) with bicubic interpolation,
    # channels-first for F.interpolate, channels-last on the way out.
    grid = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(grid, size=(h, w), mode="bicubic", align_corners=False)
    return resized.permute(0, 2, 3, 1)
190
+
191
+
192
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        # A strided convolution implements the patch split + linear projection
        # in a single op.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        # Project, then move channels last: (B, C, H, W) -> (B, H, W, C).
        return self.proj(x).permute(0, 2, 3, 1)
219
+
220
+
221
+
222
+
223
+ from math import pi
224
+
225
+ import torch
226
+ from torch import nn
227
+
228
+ from einops import rearrange, repeat
229
+
230
+
231
def broadcat(tensors, dim = -1):
    """Broadcast-concatenate: expand every tensor along all non-`dim` axes to
    the common broadcast size, then concatenate along `dim`."""
    count = len(tensors)
    ndims = set(map(lambda t: len(t.shape), tensors))
    assert len(ndims) == 1, 'tensors must all have the same number of dimensions'
    ndim = list(ndims)[0]
    if dim < 0:
        dim = dim + ndim

    # Per-axis sizes across tensors: axis i -> (size in t0, size in t1, ...).
    per_axis = list(zip(*map(lambda t: list(t.shape), tensors)))
    target = []
    for axis, sizes in enumerate(per_axis):
        if axis == dim:
            # The concat axis keeps each tensor's own size.
            target.append(sizes)
            continue
        # A broadcastable axis mixes at most two distinct sizes.
        assert len(set(sizes)) <= 2, 'invalid dimensions for broadcastable concatentation'
        target.append((max(sizes),) * count)

    # Transpose back to one full shape per tensor and expand in place.
    expanded = [t.expand(*shape) for t, shape in zip(tensors, zip(*target))]
    return torch.cat(expanded, dim = dim)
246
+
247
+
248
def rotate_half(x):
    """Rotate each adjacent pair (x1, x2) of the last dim into (-x2, x1)."""
    # View the last dim as pairs, swap-and-negate, then flatten back.
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1 = pairs[..., 0]
    x2 = pairs[..., 1]
    rotated = torch.stack((-x2, x1), dim = -1)
    return rotated.flatten(start_dim=-2)
253
+
254
+
255
class VisionRotaryEmbedding(nn.Module):
    """2D axial rotary position embedding (RoPE) for vision.

    Builds cos/sin tables for an ft_seq_len x ft_seq_len token grid (height
    and width frequencies concatenated per cell) and rotates an explicit
    slice of the feature dimension in `forward`.
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        """
        Args:
            dim (int): number of rotary frequencies per axis (pre-duplication).
            pt_seq_len (int): per-axis sequence length used at pre-training.
            ft_seq_len (int or None): fine-tune per-axis length; defaults to pt_seq_len.
            custom_freqs (Tensor or None): user-provided frequency table.
            freqs_for (str): 'lang' | 'pixel' | 'constant' frequency schedule.
            theta (int): base of the 'lang' inverse-frequency schedule.
            max_freq (int): top frequency for the 'pixel' schedule.
            num_freqs (int): number of frequencies for the 'constant' schedule.
        """
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # Positions rescaled so the fine-tune grid spans the pre-train range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Duplicate each frequency so consecutive feature pairs share it.
        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        # Axial 2D table: height and width frequencies concatenated per cell.
        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

        # Buffers (not parameters): move with the module, never trained.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index = 0):
        # Rotate only features [start_index, start_index + rot_dim); everything
        # outside that slice passes through untouched.
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        # Standard RoPE rotation: t*cos + rotate_half(t)*sin.
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim = -1)
302
+
303
+
304
class VisionRotaryEmbeddingFast(nn.Module):
    """Precomputed 2D rotary embedding applied to the whole feature dim.

    Unlike `VisionRotaryEmbedding`, the cos/sin tables are flattened over the
    token grid up front, so `forward` reduces to one fused multiply-add.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        real_img_size = None
    ):
        """
        Args:
            dim (int): number of rotary frequencies per axis (pre-duplication).
            pt_seq_len (int): per-axis sequence length used at pre-training.
            ft_seq_len (int or None): fine-tune per-axis length; defaults to pt_seq_len.
            custom_freqs / freqs_for / theta / max_freq / num_freqs: frequency
                schedule selection, as in VisionRotaryEmbedding.
            real_img_size (tuple or None): actual (H, W) token-grid size. When
                given, the square ft_seq_len table is bicubically resized to it
                so non-square inputs get matching tables.
        """
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        # Positions rescaled so the fine-tune grid spans the pre-train range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Duplicate each frequency for feature pairs, then build the axial 2D
        # table (height/width frequencies concatenated per cell).
        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        # Flatten the (H, W) grid into one token axis for fast broadcasting.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        if real_img_size is not None:
            # Resize the square table to the actual (H, W) token grid.
            new_freqs_cos = F.interpolate(
                freqs_cos.reshape(1, ft_seq_len, ft_seq_len, -1).permute(0, 3, 1, 2),
                size=real_img_size,
                mode="bicubic",
                align_corners=False,
            ).permute(0, 2, 3, 1)

            new_freqs_sin = F.interpolate(
                freqs_sin.reshape(1, ft_seq_len, ft_seq_len, -1).permute(0, 3, 1, 2),
                size=real_img_size,
                mode="bicubic",
                align_corners=False,
            ).permute(0, 2, 3, 1)

            # Buffers (not parameters): move with the module, never trained.
            self.register_buffer("freqs_cos", new_freqs_cos.view(-1, freqs.shape[-1]))
            self.register_buffer("freqs_sin", new_freqs_sin.view(-1, freqs.shape[-1]))
        else:
            self.register_buffer("freqs_cos", freqs_cos)
            self.register_buffer("freqs_sin", freqs_sin)

    def forward(self, t):
        # The [:, None, :] adds a broadcast axis between tokens and features.
        # NOTE(review): assumes t's token axis matches the flattened grid size
        # of the tables -- confirm against the caller (Attention.forward).
        return t * self.freqs_cos[:, None, :] + rotate_half(t) * self.freqs_sin[:, None, :]
models/backbones/eva02/vit.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from functools import partial
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as cp
10
+
11
+ from .wrappers import Conv2d
12
+ from .batch_norm import get_norm
13
+ from .blocks import CNNBlockBase
14
+ from .fpn import _assert_strides_are_log2_contiguous
15
+
16
+ from .backbone import Backbone
17
+ from .utils import (
18
+ PatchEmbed,
19
+ add_decomposed_rel_pos,
20
+ get_abs_pos,
21
+ window_partition,
22
+ window_unpartition,
23
+ VisionRotaryEmbeddingFast,
24
+ )
25
+
26
+ try:
27
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
28
+ except ImportError:
29
+ flash_attn_func = None
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w3(norm(act(w1(x)) * w2(x))), with dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Two parallel projections: one is gated by the activation of the other.
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        # Sub-LayerNorm (EVA-02): normalize the hidden state before the output
        # projection; identity when subln is off.
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gated = self.act(self.w1(x)) * self.w2(x)
        return self.drop(self.w3(self.ffn_ln(gated)))
59
+
60
+
61
class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) feature map with optional
    rotary position embedding and optional flash-attention kernels."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        rope=None,
        xattn=True,
    ):
        """
        Args:
            dim (int): input channel dimension.
            num_heads (int): number of attention heads.
            qkv_bias (bool): add learnable bias to q and v. k stays bias-free,
                the EVA/BEiT convention.
            qk_scale (float or None): override for the attention scale.
                NOTE(review): self.scale is not applied in either attention
                path below (both use their own default scaling) -- confirm
                whether qk_scale is meant to have an effect.
            attn_head_dim (int or None): override per-head dimension.
            rope (nn.Module or None): rotary embedding applied to q and k.
            xattn (bool): use flash-attention instead of PyTorch SDPA.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Separate projections (not a fused qkv) so k can skip its bias.
        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        # Flatten the spatial grid into a token axis.
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        # (B, N, heads, head_dim) layout, as expected by rope and flash_attn.
        q = q.reshape(B, N, self.num_heads, -1)
        k = k.reshape(B, N, self.num_heads, -1)
        v = v.reshape(B, N, self.num_heads, -1)

        # Rotary position embedding on q/k. Fix: tolerate rope=None (the
        # documented default) instead of crashing on a NoneType call.
        if self.rope is not None:
            q = self.rope(q).type_as(v)
            k = self.rope(k).type_as(v)

        if self.xattn:
            # Fix: fail with a clear message when flash_attn is not installed
            # (the import at the top of the file silently sets it to None).
            assert flash_attn_func is not None, \
                "xattn=True requires the flash_attn package to be installed"
            x = flash_attn_func(q, k, v).reshape(B, N, -1)
        else:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.permute(0, 2, 1, 3)  # B, num_heads, N, C
            v = v.permute(0, 2, 1, 3)  # B, num_heads, N, C
            x = F.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        # Restore the spatial layout.
        x = x.view(B, H, W, C)

        return x
125
+
126
+
127
class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        # 1x1 reduce -> 3x3 -> 1x1 expand; each conv followed by norm (+ act).
        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer: the residual branch contributes nothing
        # at initialization, so the block starts as an identity mapping.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        # Runs conv1->norm1->act1->conv2->norm2->act2->conv3->norm3 by relying
        # on attribute registration order in __init__; do not reorder fields.
        for layer in self.children():
            out = layer(out)

        # Residual connection; note there is no final activation by design.
        out = x + out
        return out
186
+
187
+
188
class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
        use_act_checkpoint=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            rope (nn.Module or None): rotary position embedding forwarded to
                the Attention layer.
            xattn (bool): use flash-attention inside the Attention layer.
            use_act_checkpoint (bool): If True, recompute activations in the
                backward pass to save memory (only active while training with
                gradients enabled; see `forward`).
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
        )

        # Imported here (not at module top) -- keeps .drop out of the import
        # graph until a Block is actually constructed.
        from .drop import DropPath

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        # EVA-02 uses a SwiGLU MLP with sub-LayerNorm instead of a vanilla MLP.
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

        self.use_act_checkpoint = use_act_checkpoint

    def inner_forward(self, x):
        # Pre-norm transformer layout: x + Attn(LN(x)), then x + MLP(LN(x)).
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            # The conv residual block is channels-first; permute in and out
            # of the (B, H, W, C) layout used by the transformer.
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x

    def forward(self, x):
        # Checkpointing only helps (and only works) when gradients are flowing,
        # hence the training/requires_grad guard.
        if self.training and x.requires_grad and self.use_act_checkpoint:
            return cp.checkpoint(self.inner_forward, x)
        else:
            return self.inner_forward(x)
285
+
286
+
287
class ViT(Backbone):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=True,
        frozen_blocks=-1,
    ):
        """
        Args:
            img_size (int): Input image size (per axis) used to size the
                global-attention rotary tables.
            real_img_size (tuple): actual (H, W) input size; the global rope
                table is resized to this token grid.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            pt_hw_seq_len (int): per-axis sequence length at pre-training,
                used by the rotary embeddings.
            intp_freq (bool): interpolate rotary frequencies to the fine-tune
                sequence length.
            window_size (int): Window size for window attention blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
            xattn (bool): use flash-attention in the attention blocks.
            frozen_blocks (int): when >= 0, freeze the patch embedding, the
                absolute position embedding (if any) and the first
                `frozen_blocks` transformer blocks.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token
        self.frozen_blocks = frozen_blocks

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # Rotary embeddings rotate half of each head's dimension.
        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size
        # Convert the pixel size to a token-grid size.
        real_img_size = (real_img_size[0] // patch_size, real_img_size[1] // patch_size)

        # Separate rope tables for windowed blocks (window-sized grid) and
        # global blocks (full, possibly non-square grid).
        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
            real_img_size=real_img_size
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn,
                use_act_checkpoint=use_act_checkpoint
            )
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linear weights, zero biases; unit LayerNorm.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Resize the pretrain pos-embed grid to the current token grid.
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)

        # Return channels-first, keyed by the configured feature name.
        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs

    def _freeze_stages(self):
        def freeze_module(m):
            # eval() also freezes any normalization statistics.
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.frozen_blocks >= 0:
            freeze_module(self.patch_embed)
            # Fix: pos_embed is None when use_abs_pos=False; guard before
            # touching requires_grad to avoid an AttributeError.
            if self.pos_embed is not None:
                self.pos_embed.requires_grad = False

            for i in range(0, self.frozen_blocks):
                freeze_module(self.blocks[i])

    def train(self, mode=True):
        # Re-apply freezing on every mode switch, since train() would
        # otherwise flip frozen submodules back to training mode.
        super().train(mode)
        self._freeze_stages()
446
+
447
+
448
class SimpleFeaturePyramid(Backbone):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming
                from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super(SimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors

        # A scale > 1 upsamples the backbone feature, so the effective output
        # stride shrinks by that factor (e.g. stride 16 / scale 4 -> stride 4).
        input_shapes = net.output_shape()
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        # Conv bias is only useful when no norm layer follows it.
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                # Two stride-2 deconvs: 4x upsampling, channels cut to dim // 4.
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                # Identity resolution: only the 1x1/3x3 output convs below.
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # Every pyramid level ends with a 1x1 channel projection followed
            # by a 3x3 smoothing conv, both norm-wrapped (standard FPN output head).
            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            # NOTE: `stage` is intentionally read after this loop ends — the
            # top_block levels are numbered starting above the last stage.
            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    @property
    def padding_constraints(self):
        return {
            # NOTE(review): "size_divisiblity" (sic) matches the upstream
            # detectron2 key spelling; verify what consumers look up before
            # fixing the typo — renaming the key is a behavior change.
            "size_divisiblity": self._size_divisibility,
            "square_size": self._square_pad,
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to pyramid feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.net(x)
        features = bottom_up_features[self.in_feature]
        results = []

        # Every pyramid stage consumes the same single backbone feature map.
        for stage in self.stages:
            results.append(stage(features))

        if self.top_block is not None:
            # The top block may read either a raw backbone output or one of
            # the pyramid outputs produced above.
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}
589
+
590
+
591
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    # Default: treat the parameter as the deepest layer, i.e. no decay.
    depth = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            # Embeddings sit below block 0 and decay the most.
            depth = 0
        elif ".blocks." in name and ".residual." not in name:
            # Parameter names look like "...blocks.<idx>...."; extract <idx>.
            suffix = name[name.find(".blocks."):]
            block_index = int(suffix.split(".")[2])
            depth = block_index + 1
    return lr_decay_rate ** (num_layers + 1 - depth)
models/backbones/eva02/wrappers.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ """
3
+ Wrappers around on some nn functions, mainly to support empty tensors.
4
+
5
+ Ideally, add support directly in PyTorch to empty tensors in those functions.
6
+
7
+ These can be removed once https://github.com/pytorch/pytorch/issues/12013
8
+ is implemented
9
+ """
10
+
11
+ import warnings
12
+ from typing import List, Optional
13
+ import torch
14
+ from torch.nn import functional as F
15
+
16
+
17
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Convert a list of integer scalars (or scalar Tensors) into a 1-D tensor,
    in a way that is both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensors so the output stays
    connected to the inputs; in scripting or eager mode, `x` should be a list
    of plain ints.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"
        # `as_tensor` would record the values as constants in the trace,
        # so stack the scalar tensors instead.
        stacked = torch.stack(x)
        # Move devices only when needed, to avoid baking a device into the trace.
        if stacked.device != device:
            stacked = stacked.to(device=device)
        return stacked
    return torch.as_tensor(x, device=device)
37
+
38
+
39
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Same as `torch.cat`, but skips the copy when the list holds a single tensor.
    """
    assert isinstance(tensors, (list, tuple))
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
47
+
48
+
49
def empty_input_loss_func_wrapper(loss_func):
    """Wrap a reduction-style loss so empty batches yield 0 instead of NaN."""

    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        if reduction == "mean" and target.numel() == 0:
            # Multiply by 0 rather than returning a constant, so the result
            # stays connected to the autograd graph.
            return input.sum() * 0.0
        return loss_func(input, target, reduction=reduction, **kwargs)

    return wrapped_loss_func
59
+
60
+
61
# Drop-in replacement for `F.cross_entropy` that returns a zero (but
# gradient-connected) loss on empty targets instead of NaN.
cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)
62
+
63
+
64
class _NewEmptyTensorOp(torch.autograd.Function):
    """Autograd op producing an uninitialized tensor of `new_shape` from `x`,
    while keeping the graph connected so backward yields a gradient with
    the original input's shape."""

    @staticmethod
    def forward(ctx, x, new_shape):
        # Remember the input shape so backward can emit a matching empty grad.
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        # Second return is None: `new_shape` is a non-tensor argument.
        return _NewEmptyTensorOp.apply(grad, shape), None
74
+
75
+
76
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        # Pop the extra kwargs first so torch.nn.Conv2d.__init__ never sees them.
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # record=True suppresses warnings raised on the empty-input path.
            with warnings.catch_warnings(record=True):
                if x.numel() == 0 and self.training:
                    # https://github.com/pytorch/pytorch/issues/12013
                    assert not isinstance(
                        self.norm, torch.nn.SyncBatchNorm
                    ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        # Norm before activation, per the class docstring's assumption.
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
121
+
122
+
123
# Plain re-exports of the torch equivalents.
# NOTE(review): presumably kept so call sites can import every layer type
# from this module uniformly — verify against callers before removing.
ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear
127
+
128
+
129
def nonzero_tuple(x):
    """
    A 'as_tuple=True' version of torch.nonzero to support torchscript.
    because of https://github.com/pytorch/pytorch/issues/38718
    """
    if not torch.jit.is_scripting():
        return x.nonzero(as_tuple=True)
    # Scripting path: emulate as_tuple=True by unbinding index columns.
    if x.dim() == 0:
        x = x.unsqueeze(0)
    return x.nonzero().unbind(1)
140
+
141
+
142
# Scripted-if-tracing so the device move is not frozen into a trace;
# the body must therefore stay TorchScript-compatible.
@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing friendly way to cast tensor to another tensor's device. Device will be treated
    as constant during tracing, scripting the casting process as whole can workaround this issue.
    """
    return src.to(dst.device)
models/sparsebev.py CHANGED
@@ -14,7 +14,7 @@ from .utils import GridMask, pad_multiple, GpuPhotoMetricDistortion
14
  class SparseBEV(MVXTwoStageDetector):
15
  def __init__(self,
16
  data_aug=None,
17
- stop_prev_grad=False,
18
  pts_voxel_layer=None,
19
  pts_voxel_encoder=None,
20
  pts_middle_encoder=None,
@@ -99,12 +99,12 @@ class SparseBEV(MVXTwoStageDetector):
99
  for img_meta in img_metas:
100
  img_meta.update(input_shape=input_shape)
101
 
102
- if self.training and self.stop_prev_grad:
103
  H, W = input_shape
104
  img = img.reshape(B, -1, 6, C, H, W)
105
 
106
- img_grad = img[:, :1]
107
- img_nograd = img[:, 1:]
108
 
109
  all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]
110
 
 
14
  class SparseBEV(MVXTwoStageDetector):
15
  def __init__(self,
16
  data_aug=None,
17
+ stop_prev_grad=0,
18
  pts_voxel_layer=None,
19
  pts_voxel_encoder=None,
20
  pts_middle_encoder=None,
 
99
  for img_meta in img_metas:
100
  img_meta.update(input_shape=input_shape)
101
 
102
+ if self.training and self.stop_prev_grad > 0:
103
  H, W = input_shape
104
  img = img.reshape(B, -1, 6, C, H, W)
105
 
106
+ img_grad = img[:, :self.stop_prev_grad]
107
+ img_nograd = img[:, self.stop_prev_grad:]
108
 
109
  all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]
110
 
val.py CHANGED
@@ -112,6 +112,7 @@ def main():
112
  logging.info('Creating model: %s' % cfgs.model.type)
113
  model = build_model(cfgs.model)
114
  model.cuda()
 
115
 
116
  if world_size > 1:
117
  model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
 
112
  logging.info('Creating model: %s' % cfgs.model.type)
113
  model = build_model(cfgs.model)
114
  model.cuda()
115
+ model.fp16_enabled = True
116
 
117
  if world_size > 1:
118
  model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)