Upload model files
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- .gitignore +9 -0
- configs/__init__.py +0 -0
- configs/models/default.yaml +29 -0
- configs/paths.py +7 -0
- configs/run/demo.yaml +12 -0
- datasets/__init__.py +0 -0
- datasets/agora.py +111 -0
- datasets/base.py +521 -0
- datasets/bedlam.py +72 -0
- datasets/common.py +34 -0
- datasets/multiple_datasets.py +49 -0
- demo/img0.png +3 -0
- demo/img1.jpeg +0 -0
- demo/img2.jpg +3 -0
- docs/fix_chumpy.md +44 -0
- engines/__init__.py +0 -0
- engines/engine.py +347 -0
- engines/funcs/__init__.py +0 -0
- engines/funcs/eval_funcs.py +362 -0
- engines/funcs/infer_funcs.py +86 -0
- figures/pipeline.png +3 -0
- figures/qualitative_results.png +3 -0
- figures/results.png +3 -0
- figures/results_3d.gif +3 -0
- main.py +52 -0
- models/__init__.py +16 -0
- models/criterion.py +449 -0
- models/decoder.py +388 -0
- models/dn_components.py +193 -0
- models/encoders/__init__.py +52 -0
- models/encoders/dinov2/layers/__init__.py +11 -0
- models/encoders/dinov2/layers/attention.py +89 -0
- models/encoders/dinov2/layers/block.py +260 -0
- models/encoders/dinov2/layers/dino_head.py +58 -0
- models/encoders/dinov2/layers/drop_path.py +34 -0
- models/encoders/dinov2/layers/layer_scale.py +27 -0
- models/encoders/dinov2/layers/mlp.py +40 -0
- models/encoders/dinov2/layers/patch_embed.py +88 -0
- models/encoders/dinov2/layers/swiglu_ffn.py +72 -0
- models/encoders/dinov2/models/__init__.py +43 -0
- models/encoders/dinov2/models/vision_transformer.py +542 -0
- models/human_models/__init__.py +1 -0
- models/human_models/smpl_models.py +69 -0
- models/matcher.py +159 -0
- models/position_encoding.py +155 -0
- models/sat_model.py +767 -0
- requirements.txt +13 -0
- utils/__init__.py +1 -0
- utils/box_ops.py +139 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
demo/img0.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
demo/img2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
figures/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
figures/qualitative_results.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
figures/results_3d.gif filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
figures/results.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
debug_datas.py
|
| 2 |
+
outputs/
|
| 3 |
+
weights/
|
| 4 |
+
results/
|
| 5 |
+
tmps/
|
| 6 |
+
**.out
|
| 7 |
+
**/__pycache__/
|
| 8 |
+
datasets_visualization/
|
| 9 |
+
demo_results/
|
configs/__init__.py
ADDED
|
File without changes
|
configs/models/default.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_size: 1288
|
| 2 |
+
encoder: 'vitb'
|
| 3 |
+
|
| 4 |
+
# decoder
|
| 5 |
+
hidden_dim: 768
|
| 6 |
+
nheads: 4
|
| 7 |
+
dec_layers: 6
|
| 8 |
+
dim_feedforward: 2048
|
| 9 |
+
dropout: 0.0
|
| 10 |
+
num_queries: 50
|
| 11 |
+
transformer_activation: "relu"
|
| 12 |
+
|
| 13 |
+
sat_cfg:
|
| 14 |
+
use_sat: True
|
| 15 |
+
share_patch_embed: False
|
| 16 |
+
preprocess_pos_embed: False
|
| 17 |
+
num_lvls: 3
|
| 18 |
+
lvl_embed: True
|
| 19 |
+
get_map_layer: 3
|
| 20 |
+
use_additional_blocks: True
|
| 21 |
+
conf_thresh: 0.3
|
| 22 |
+
scale_thresh: 0.5
|
| 23 |
+
|
| 24 |
+
dn_cfg:
|
| 25 |
+
use_dn: True
|
| 26 |
+
dn_number: 10
|
| 27 |
+
tgt_embed_type: "params"
|
| 28 |
+
box_noise_scale: 0.4
|
| 29 |
+
tgt_noise_scale: 0.2
|
configs/paths.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Root directory that holds all datasets (relative to the repo root).
dataset_root = '../datasets'

# SMPL body-model assets (weights and the mean-parameter file).
smpl_model_path = './weights/smpl_data'
smpl_mean_path = './weights/smpl_data/smpl/smpl_mean_params.npz'

# Pretrained DINOv2 backbone checkpoints (ViT-B/14 and ViT-L/14).
dinov2_vitb14_path = './weights/dinov2/dinov2_vitb14_pretrain.pth'
dinov2_vitl14_path = './weights/dinov2/dinov2_vitl14_pretrain.pth'
|
configs/run/demo.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: default
|
| 2 |
+
|
| 3 |
+
pretrain: True
|
| 4 |
+
pretrain_path: './weights/sat_hmr/sat_644.pth'
|
| 5 |
+
|
| 6 |
+
input_dir: './demo'
|
| 7 |
+
output_dir: './demo_results'
|
| 8 |
+
conf_thresh: [0.3]
|
| 9 |
+
infer_batch_size: 1
|
| 10 |
+
infer_num_workers: 8
|
| 11 |
+
distributed_infer: True
|
| 12 |
+
|
datasets/__init__.py
ADDED
|
File without changes
|
datasets/agora.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.dataset import Dataset
|
| 4 |
+
import os
|
| 5 |
+
from configs.paths import dataset_root
|
| 6 |
+
import copy
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from .base import BASE
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AGORA(BASE):
    """AGORA dataset wrapper.

    Loads SMPL-neutral annotations for the train/validation splits; the test
    split has no annotations and is served in inference mode (image paths only).
    """

    def __init__(self, split='train', **kwargs):
        super(AGORA, self).__init__(**kwargs)
        assert split in ['train','test','validation']

        self.ds_name = 'agora'
        self.split = split
        self.dataset_path = os.path.join(dataset_root, 'agora')

        # no annotations are available for AGORA-test
        if split == 'test':
            self.mode = 'infer'
            self.img_names = os.listdir(os.path.join(self.dataset_path, self.split))
            return

        # train uses the fitted annotations; validation uses the plain ones
        template = 'annots_smpl_{}_fit.npz' if self.split == 'train' else 'annots_smpl_{}.npz'
        annots_path = os.path.join(self.dataset_path, 'smpl_neutral_annots', template.format(split))
        self.annots = np.load(annots_path, allow_pickle=True)['annots'][()]
        self.img_names = list(self.annots.keys())

    def __len__(self):
        return len(self.img_names)

    def get_raw_data(self, idx):
        """Return the raw annotation dict for the image at ``idx`` (wraps around)."""
        img_id = idx % len(self.img_names)
        img_name = self.img_names[img_id]
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        if self.mode == 'infer':
            return {'img_path': img_path,
                    'img_name': img_name,
                    'ds': 'agora'}

        annots = copy.deepcopy(self.annots[img_name])

        valid_idx = np.where(annots['isValid'])[0]
        # this should not happen
        if len(valid_idx) == 0:
            print(img_name, 'lack valid person')
            exit(0)

        cam_intrinsics = torch.from_numpy(np.array(annots['cam_intrinsics']))
        cam_rot = torch.from_numpy(np.array(annots['cam_rot'])[valid_idx])
        cam_trans = torch.from_numpy(np.array(annots['cam_trans'])[valid_idx])

        betas_list, poses_list, transl_list = [], [], []
        kid = []
        eval_mode = (self.mode == 'eval')
        # per-person occlusion level (0-9 bucket), only needed for evaluation
        occ_levels = []

        for person_idx, is_valid in enumerate(annots['isValid']):
            if not is_valid:
                continue

            gt = annots['smpl_gt'][person_idx]
            betas_list.append(torch.from_numpy(gt['betas'].flatten()[:10]))
            # global orient followed by the body pose, flattened to one vector
            poses_list.append(torch.cat([torch.from_numpy(gt['global_orient'].flatten()),
                                         torch.from_numpy(gt['body_pose'].flatten())]))
            transl_list.append(torch.from_numpy(gt['transl'].flatten()))
            kid.append(annots['kid'][person_idx])
            if eval_mode:
                occ_levels.append(int(annots['occlusion'][person_idx] // 10))

        betas = torch.stack(betas_list)
        poses = torch.stack(poses_list)
        transl = torch.stack(transl_list)

        raw_data = {'img_path': img_path,
                    'ds': 'agora',
                    'pnum': len(betas),
                    'betas': betas.float(),
                    'poses': poses.float(),
                    'transl': transl.float(),
                    'kid': torch.tensor(kid),
                    'cam_rot': cam_rot.float(),
                    'cam_trans': cam_trans.float(),
                    'cam_intrinsics': cam_intrinsics.float(),
                    '3d_valid': True,
                    'age_valid': True,
                    'detect_all_people': True
                    }

        if eval_mode:
            raw_data['occ_level'] = torch.tensor(occ_levels)

        return raw_data
datasets/base.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.utils.data.dataset import Dataset
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from utils.visualization import tensor_to_BGR, vis_meshes_img, vis_boxes, vis_scale_img, pad_img, get_colors_rgb, vis_sat
|
| 7 |
+
from utils.transforms import unNormalize, to_zorder
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import math
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import cv2
|
| 12 |
+
import torch
|
| 13 |
+
import copy
|
| 14 |
+
from math import radians,sin,cos
|
| 15 |
+
from utils import constants
|
| 16 |
+
from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
|
| 17 |
+
from utils.constants import smpl_24_flip, smpl_root_idx
|
| 18 |
+
from utils.map import gen_scale_map, build_z_map
|
| 19 |
+
from configs.paths import smpl_model_path
|
| 20 |
+
from models.human_models import SMPL_Layer, smpl_gendered
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BASE(Dataset):
|
| 24 |
+
def __init__(self, input_size = 1288, aug = True, mode = 'train',
|
| 25 |
+
human_type = 'smpl',
|
| 26 |
+
sat_cfg = None,
|
| 27 |
+
aug_cfg = None):
|
| 28 |
+
self.input_size = input_size
|
| 29 |
+
self.aug = aug
|
| 30 |
+
if mode not in ['train', 'eval', 'infer']:
|
| 31 |
+
raise NotImplementedError
|
| 32 |
+
if human_type not in ['smpl', 'no']:
|
| 33 |
+
raise NotImplementedError
|
| 34 |
+
self.mode = mode
|
| 35 |
+
self.human_type = human_type
|
| 36 |
+
assert sat_cfg is not None
|
| 37 |
+
self.use_sat = sat_cfg['use_sat']
|
| 38 |
+
self.sat_cfg = sat_cfg
|
| 39 |
+
|
| 40 |
+
if self.use_sat:
|
| 41 |
+
assert input_size % 56 == 0
|
| 42 |
+
|
| 43 |
+
if self.mode == 'train' and aug_cfg is None:
|
| 44 |
+
aug_cfg = {'rot_range': [-15, 15],
|
| 45 |
+
'scale_range': [0.8, 1.8],
|
| 46 |
+
'flip_ratio': 0.5,
|
| 47 |
+
'crop_ratio': 0.}
|
| 48 |
+
self.aug_cfg = aug_cfg
|
| 49 |
+
|
| 50 |
+
if human_type == 'smpl':
|
| 51 |
+
self.poses_flip = smpl_24_flip
|
| 52 |
+
self.num_poses = 24
|
| 53 |
+
self.num_betas = 10
|
| 54 |
+
self.num_kpts = 45
|
| 55 |
+
self.human_model = smpl_gendered
|
| 56 |
+
|
| 57 |
+
self.vis_thresh = 4 # least num visible kpts for a valid individual
|
| 58 |
+
|
| 59 |
+
self.img_keys = ['img_path', 'ds',
|
| 60 |
+
'pnum', 'img_size',
|
| 61 |
+
'resize_rate', 'cam_intrinsics',
|
| 62 |
+
'3d_valid', 'detect_all_people',
|
| 63 |
+
'scale_map', 'scale_map_pos', 'scale_map_hw']
|
| 64 |
+
self.human_keys = ['boxes', 'labels',
|
| 65 |
+
'poses', 'betas',
|
| 66 |
+
'transl', 'verts',
|
| 67 |
+
'j3ds', 'j2ds', 'j2ds_mask',
|
| 68 |
+
'depths', 'focals', 'genders']
|
| 69 |
+
|
| 70 |
+
z_depth = math.ceil(math.log2(self.input_size//28))
|
| 71 |
+
self.z_order_map, self.y_coords, self.x_coords = build_z_map(z_depth)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_raw_data(self, idx):
|
| 76 |
+
raise NotImplementedError
|
| 77 |
+
|
| 78 |
+
def get_aug_dict(self):
|
| 79 |
+
if self.aug:
|
| 80 |
+
rot = random.uniform(*self.aug_cfg['rot_range'])
|
| 81 |
+
flip = random.random() <= self.aug_cfg['flip_ratio']
|
| 82 |
+
scale = random.uniform(*self.aug_cfg['scale_range'])
|
| 83 |
+
crop = random.random() <= self.aug_cfg['crop_ratio']
|
| 84 |
+
else:
|
| 85 |
+
rot = 0.
|
| 86 |
+
flip = False
|
| 87 |
+
scale = 1.
|
| 88 |
+
crop = False
|
| 89 |
+
|
| 90 |
+
return {'rot':rot, 'flip':flip, 'scale':scale, 'crop': crop}
|
| 91 |
+
|
| 92 |
+
def process_img(self, img, meta_data, rot = 0., flip = False, scale = 1.0, crop = False):
|
| 93 |
+
# randomly crop (similar to scale)
|
| 94 |
+
if self.mode == 'train' and crop:
|
| 95 |
+
|
| 96 |
+
h, w = img.shape[:2]
|
| 97 |
+
if h < w :
|
| 98 |
+
clip_ratio = random.uniform(0.5, 0.9)
|
| 99 |
+
tgt_h, tgt_w = int(h*clip_ratio), int(w*clip_ratio)
|
| 100 |
+
|
| 101 |
+
img = img[:tgt_h,(w-tgt_w)//2:(w+tgt_w)//2,:].copy()
|
| 102 |
+
cam_intrinsics = meta_data['cam_intrinsics']
|
| 103 |
+
cam_intrinsics[:,0,2] -= (w-tgt_w)//2
|
| 104 |
+
meta_data.update({'cam_intrinsics': cam_intrinsics})
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# resize
|
| 108 |
+
img_size = torch.tensor(img.shape[:2])
|
| 109 |
+
if img_size[1] >= img_size[0]:
|
| 110 |
+
resize_rate = self.input_size/img_size[1]
|
| 111 |
+
img = cv2.resize(img,dsize=(self.input_size,int(resize_rate*img_size[0])))
|
| 112 |
+
img_size = torch.tensor([int(resize_rate*img_size[0]),self.input_size])
|
| 113 |
+
else:
|
| 114 |
+
resize_rate = self.input_size/img_size[0]
|
| 115 |
+
img = cv2.resize(img,dsize=(int(resize_rate*img_size[1]),self.input_size))
|
| 116 |
+
img_size = torch.tensor([self.input_size,int(resize_rate*img_size[1])])
|
| 117 |
+
meta_data.update({'img_size': img_size, 'resize_rate': resize_rate})
|
| 118 |
+
|
| 119 |
+
# flip
|
| 120 |
+
if flip:
|
| 121 |
+
img = np.flip(img, axis = 1)
|
| 122 |
+
rot = -rot
|
| 123 |
+
|
| 124 |
+
# rot and scale
|
| 125 |
+
img_valid = np.full((img.shape[0], img.shape[1]), 255, dtype = np.uint8)
|
| 126 |
+
M = cv2.getRotationMatrix2D((int(img_size[1]/2),int(img_size[0]/2)), rot, scale)
|
| 127 |
+
img = cv2.warpAffine(img, M, dsize = (img.shape[1],img.shape[0]))
|
| 128 |
+
img_valid = cv2.warpAffine(img_valid, M, dsize = (img.shape[1],img.shape[0]))
|
| 129 |
+
meta_data.update({'img_valid': img_valid})
|
| 130 |
+
|
| 131 |
+
return img
|
| 132 |
+
|
| 133 |
+
def occlusion_aug(self, meta_data):
|
| 134 |
+
occ_boxes = []
|
| 135 |
+
imght, imgwidth = meta_data['img_size']
|
| 136 |
+
for bbox in box_cxcywh_to_xyxy(meta_data['boxes']):
|
| 137 |
+
bbox = bbox.clone()
|
| 138 |
+
bbox *= self.input_size
|
| 139 |
+
xmin, ymin = bbox[:2]
|
| 140 |
+
xmax, ymax = bbox[2:]
|
| 141 |
+
|
| 142 |
+
if random.random() <= 0.6:
|
| 143 |
+
counter = 0
|
| 144 |
+
while True:
|
| 145 |
+
# force to break if no suitable occlusion
|
| 146 |
+
if counter > 5:
|
| 147 |
+
synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
|
| 148 |
+
break
|
| 149 |
+
counter += 1
|
| 150 |
+
|
| 151 |
+
area_min = 0.0
|
| 152 |
+
area_max = 0.3
|
| 153 |
+
synth_area = (random.random() * (area_max - area_min) + area_min) * (xmax - xmin) * (ymax - ymin)
|
| 154 |
+
|
| 155 |
+
ratio_min = 0.5
|
| 156 |
+
ratio_max = 1 / 0.5
|
| 157 |
+
synth_ratio = (random.random() * (ratio_max - ratio_min) + ratio_min)
|
| 158 |
+
|
| 159 |
+
synth_h = math.sqrt(synth_area * synth_ratio)
|
| 160 |
+
synth_w = math.sqrt(synth_area / synth_ratio)
|
| 161 |
+
synth_xmin = random.random() * ((xmax - xmin) - synth_w - 1) + xmin
|
| 162 |
+
synth_ymin = random.random() * ((ymax - ymin) - synth_h - 1) + ymin
|
| 163 |
+
|
| 164 |
+
if synth_xmin >= 0 and synth_ymin >= 0 and synth_xmin + synth_w < imgwidth and synth_ymin + synth_h < imght:
|
| 165 |
+
synth_xmin = int(synth_xmin)
|
| 166 |
+
synth_ymin = int(synth_ymin)
|
| 167 |
+
synth_w = int(synth_w)
|
| 168 |
+
synth_h = int(synth_h)
|
| 169 |
+
break
|
| 170 |
+
else:
|
| 171 |
+
synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
|
| 172 |
+
occ_boxes.append((synth_ymin, synth_h, synth_xmin, synth_w))
|
| 173 |
+
return occ_boxes
|
| 174 |
+
|
| 175 |
+
def get_boxes(self, meta_data):
|
| 176 |
+
j2ds = meta_data['j2ds']
|
| 177 |
+
j2ds_mask = meta_data['j2ds_mask']
|
| 178 |
+
pnum = meta_data['pnum']
|
| 179 |
+
|
| 180 |
+
bboxes_list = []
|
| 181 |
+
|
| 182 |
+
for i in range(pnum):
|
| 183 |
+
kpts = j2ds[i].clone()
|
| 184 |
+
min_xy = kpts.min(dim = 0)[0]
|
| 185 |
+
max_xy = kpts.max(dim = 0)[0]
|
| 186 |
+
bbox_xyxy = torch.cat([min_xy, max_xy], dim = 0)
|
| 187 |
+
bboxes_list.append(bbox_xyxy)
|
| 188 |
+
|
| 189 |
+
imght, imgwidth = meta_data['img_size']
|
| 190 |
+
boxes = box_xyxy_to_cxcywh(torch.stack(bboxes_list)) / self.input_size
|
| 191 |
+
boxes[...,2:] *= 1.2
|
| 192 |
+
boxes = box_cxcywh_to_xyxy(boxes)
|
| 193 |
+
boxes[...,[0,2]] = boxes[...,[0,2]].clamp(min=0.01,max=(imgwidth-1)/self.input_size)
|
| 194 |
+
boxes[...,[1,3]] = boxes[...,[1,3]].clamp(min=0.01,max=(imght-1)/self.input_size)
|
| 195 |
+
boxes = box_xyxy_to_cxcywh(boxes)
|
| 196 |
+
|
| 197 |
+
meta_data.update({'boxes': boxes})
|
| 198 |
+
|
| 199 |
+
def process_cam(self, meta_data, rot = 0., flip = False, scale = 1.):
|
| 200 |
+
img_size = meta_data['img_size']
|
| 201 |
+
resize_rate = meta_data['resize_rate']
|
| 202 |
+
rot_aug_mat = meta_data['rot_aug_mat']
|
| 203 |
+
cam_intrinsics = meta_data['cam_intrinsics']
|
| 204 |
+
# cam_int
|
| 205 |
+
# resize
|
| 206 |
+
cam_intrinsics[:,0:2,2] *= resize_rate * scale
|
| 207 |
+
cam_intrinsics[:,[0,1],[0,1]] *= resize_rate * scale
|
| 208 |
+
cam_intrinsics[:,0,2] += (1-scale)*img_size[1]/2
|
| 209 |
+
cam_intrinsics[:,1,2] += (1-scale)*img_size[0]/2
|
| 210 |
+
# rotation
|
| 211 |
+
princpt = cam_intrinsics[:,0:2,2].clone()
|
| 212 |
+
princpt[...,0] -= img_size[1]/2
|
| 213 |
+
princpt[...,1] -= img_size[0]/2
|
| 214 |
+
princpt = torch.matmul(princpt,rot_aug_mat[:2,:2].transpose(-1,-2))
|
| 215 |
+
princpt[...,0] += img_size[1]/2
|
| 216 |
+
princpt[...,1] += img_size[0]/2
|
| 217 |
+
cam_intrinsics[:,0:2,2] = princpt
|
| 218 |
+
# flip
|
| 219 |
+
if flip:
|
| 220 |
+
cam_intrinsics[:,0,2] = img_size[1]-cam_intrinsics[:,0,2]
|
| 221 |
+
meta_data.update({'cam_intrinsics': cam_intrinsics})
|
| 222 |
+
|
| 223 |
+
#cam_ext
|
| 224 |
+
new_cam_rot = torch.matmul(rot_aug_mat.unsqueeze(0),meta_data['cam_rot'])
|
| 225 |
+
new_cam_trans = torch.matmul(meta_data['cam_trans'],rot_aug_mat.transpose(-1,-2))
|
| 226 |
+
meta_data.update({'cam_rot': new_cam_rot,'cam_trans':new_cam_trans})
|
| 227 |
+
|
| 228 |
+
    def process_smpl(self, meta_data, rot = 0., flip = False, scale = 1.):
        """Transform the SMPL annotations into the augmented camera frame.

        Folds the per-person camera rotation into the global orient, applies
        the horizontal flip to the pose parameters, runs the SMPL layer, and
        recomputes the translation / vertices / joints in camera coordinates.
        Updates 'poses', 'transl', 'verts' and 'j3ds' in ``meta_data``.
        Note: ``rot`` and ``scale`` are unused here; the augmentation rotation
        reaches this method via ``meta_data['cam_rot']`` (see process_cam).
        """
        poses = meta_data['poses']
        bs = poses.shape[0]
        assert poses.ndim == 2
        assert tuple(poses.shape) == (bs, self.num_poses*3)
        # Merge rotation to smpl global_orient
        global_orient = poses[:,:3].clone()
        cam_rot = meta_data['cam_rot'].numpy()
        for i in range(global_orient.shape[0]):
            root_pose = global_orient[i].view(1, 3).numpy()
            R = cam_rot[i].reshape(3,3)
            # axis-angle -> matrix, pre-multiply by camera rotation, -> axis-angle
            root_pose, _ = cv2.Rodrigues(root_pose)
            root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
            root_pose = torch.from_numpy(root_pose).flatten()
            global_orient[i] = root_pose
        poses[:,:3] = global_orient

        # Flip smpl parameters
        if flip:
            poses = poses.reshape(bs, self.num_poses, 3)
            # swap left/right joints via the flip permutation
            poses = poses[:, self.poses_flip, :]
            poses[..., 1:3] *= -1 # multiply -1 to y and z axis of axis-angle
            poses = poses.reshape(bs, -1)

        # Update all pose params
        meta_data.update({'poses': poses})

        # Get vertices and joints in cam_coords
        with torch.no_grad():
            smpl_kwargs = {'poses': meta_data['poses'], 'betas': meta_data['betas']}
            if 'genders' in meta_data:
                smpl_kwargs.update({'genders': meta_data['genders']})
            verts, j3ds = self.human_model(**smpl_kwargs)

        j3ds = j3ds[:, :self.num_kpts, :]
        root = j3ds[:,smpl_root_idx,:].clone() # smpl root
        # new translation in cam_coords
        # rotate the world-frame root position into the camera frame, then
        # subtract the (rotation-invariant) body-frame root offset
        transl = torch.bmm((root+meta_data['transl']).reshape(-1,1,3),meta_data['cam_rot'].transpose(-1,-2)).reshape(-1,3)\
                 +meta_data['cam_trans']-root
        if flip:
            transl[...,0] = -transl[...,0]

        meta_data.update({'transl': transl})

        verts = verts + transl.reshape(-1,1,3)
        j3ds = j3ds + transl.reshape(-1,1,3)
        meta_data.update({'verts': verts, 'j3ds': j3ds})
|
| 275 |
+
|
| 276 |
+
def project_joints(self, meta_data):
|
| 277 |
+
j3ds = meta_data['j3ds']
|
| 278 |
+
cam_intrinsics = meta_data['cam_intrinsics']
|
| 279 |
+
j2ds_homo = torch.matmul(j3ds,cam_intrinsics.transpose(-1,-2))
|
| 280 |
+
j2ds = j2ds_homo[...,:2]/(j2ds_homo[...,2,None])
|
| 281 |
+
|
| 282 |
+
meta_data.update({'j3ds': j3ds, 'j2ds': j2ds})
|
| 283 |
+
|
| 284 |
+
def check_visibility(self, meta_data):
|
| 285 |
+
img_valid = meta_data['img_valid']
|
| 286 |
+
img_size = meta_data['img_size']
|
| 287 |
+
|
| 288 |
+
j2ds = meta_data['j2ds']
|
| 289 |
+
j2ds_mask = meta_data['j2ds_mask'] if 'j2ds_mask' in meta_data else torch.ones_like(j2ds, dtype=bool)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
j2ds_vis = torch.from_numpy(img_valid[j2ds[...,1].int().clip(0,img_size[0]-1), j2ds[...,0].int().clip(0,img_size[1]-1)] > 0)
|
| 293 |
+
j2ds_vis &= (j2ds[...,1] >= 0) & (j2ds[...,1] < img_size[0])
|
| 294 |
+
j2ds_vis &= (j2ds[...,0] >= 0) & (j2ds[...,0] < img_size[1])
|
| 295 |
+
|
| 296 |
+
j2ds_invalid = ~j2ds_vis
|
| 297 |
+
j2ds_mask[j2ds_invalid] = False
|
| 298 |
+
meta_data.update({'j2ds_mask': j2ds_mask})
|
| 299 |
+
|
| 300 |
+
vis_cnt = j2ds_mask[...,0].sum(dim = -1) # num of visible joints per person
|
| 301 |
+
valid_msk = (vis_cnt >= self.vis_thresh)
|
| 302 |
+
|
| 303 |
+
pnum = valid_msk.sum().item()
|
| 304 |
+
|
| 305 |
+
if pnum == 0:
|
| 306 |
+
meta_data['pnum'] = pnum
|
| 307 |
+
return
|
| 308 |
+
|
| 309 |
+
if pnum < meta_data['pnum']:
|
| 310 |
+
meta_data['pnum'] = pnum
|
| 311 |
+
for key in self.human_keys:
|
| 312 |
+
if key in meta_data:
|
| 313 |
+
if isinstance(meta_data[key], list):
|
| 314 |
+
meta_data[key] = np.array(meta_data[key])[valid_msk].tolist()
|
| 315 |
+
else:
|
| 316 |
+
meta_data[key] = meta_data[key][valid_msk]
|
| 317 |
+
if 'cam_intrinsics' in meta_data and len(meta_data['cam_intrinsics']) > 1:
|
| 318 |
+
meta_data['cam_intrinsics'] = meta_data['cam_intrinsics'][valid_msk]
|
| 319 |
+
|
| 320 |
+
return
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
    def process_data(self, img, raw_data, rot = 0., flip = False, scale = 1., crop = False):
        """Run the full per-sample pipeline on a deep copy of ``raw_data``:
        image warping, camera/SMPL transforms, joint projection, visibility
        filtering, depth/box targets, occlusion augmentation and (optionally)
        the SAT scale map. Returns ``(img, meta_data)``; a sample that loses
        all its people in train mode comes back with ``pnum == 0``.
        """
        meta_data = copy.deepcopy(raw_data)
        # prepare rotation augmentation mat.
        rot_aug_mat = torch.tensor([[cos(radians(-rot)), -sin(radians(-rot)), 0.],
                                    [sin(radians(-rot)), cos(radians(-rot)), 0.],
                                    [0., 0., 1.]])
        meta_data.update({'rot_aug_mat': rot_aug_mat})

        img = self.process_img(img, meta_data, rot, flip, scale, crop)

        # order matters: cameras first, then SMPL (uses cam_rot), then projection
        self.process_cam(meta_data, rot, flip, scale)
        self.process_smpl(meta_data, rot, flip, scale)
        self.project_joints(meta_data)
        self.check_visibility(meta_data)
        matcher_vis = meta_data['j2ds_mask'][:,:22,0].sum(dim = -1) # num of visible joints used in Hungarian Matcher
        # reject the sample if nobody is left or any person has zero matcher joints
        if meta_data['pnum'] == 0 or not torch.all(matcher_vis):
            if self.mode == 'train':
                meta_data['pnum'] = 0
            return img, meta_data

        # depth targets: root depth plus focal-normalized depth
        j3ds = meta_data['j3ds']
        depths = j3ds[:, smpl_root_idx, [2]].clone()
        if len(meta_data['cam_intrinsics']) == 1:
            focals = torch.full_like(depths, meta_data['cam_intrinsics'][0,0,0])
        else:
            focals = meta_data['cam_intrinsics'][:,0,0][:, None]
        depths = torch.cat([depths, depths/focals],dim=-1)
        meta_data.update({'depths': depths, 'focals': focals})

        self.get_boxes(meta_data)

        # single-class detection: every person gets label 0
        meta_data.update({'labels': torch.zeros(meta_data['pnum'], dtype=int)})

        # VI. Occlusion augmentation
        if self.aug:
            occ_boxes = self.occlusion_aug(meta_data)
            for (synth_ymin, synth_h, synth_xmin, synth_w) in occ_boxes:
                img[synth_ymin:synth_ymin + synth_h, synth_xmin:synth_xmin + synth_w, :] = np.random.rand(synth_h, synth_w, 3) * 255

        if self.use_sat:
            # scale map
            # render per-patch scale targets back-to-front (farthest person first)
            boxes = meta_data['boxes']
            scales = boxes[:,2:].norm(p=2,dim=1)
            v3ds = meta_data['verts']
            depths_norm = meta_data['depths'][:,1]
            cam_intrinsics = meta_data['cam_intrinsics']
            sorted_idx = torch.argsort(depths_norm, descending=True)
            map_size = (meta_data['img_size'] + 27)//28  # ceil-divide into 28-px patches

            scale_map = gen_scale_map(scales[sorted_idx], v3ds[sorted_idx],
                                      faces = self.human_model.faces,
                                      cam_intrinsics = cam_intrinsics[sorted_idx] if len(cam_intrinsics) > 1 else cam_intrinsics,
                                      map_size = map_size,
                                      patch_size = 28,
                                      pad = True)
            # flatten the map along a z-order curve; keep positions for unflattening
            scale_map_z, _, pos_y, pos_x = to_zorder(scale_map,
                                                     z_order_map = self.z_order_map,
                                                     y_coords = self.y_coords,
                                                     x_coords = self.x_coords)
            meta_data['scale_map'] = scale_map_z
            meta_data['scale_map_pos'] = {'pos_y': pos_y, 'pos_x': pos_x}
            meta_data['scale_map_hw'] = scale_map.shape[:2]

        return img, meta_data
|
| 390 |
+
|
| 391 |
+
    def __getitem__(self, index):
        """Load, augment and normalize one sample.

        Returns ``(norm_img, meta_data)`` where ``norm_img`` is a normalized
        CHW tensor padded to a multiple of the patch size. In train mode the
        augmentation is resampled until at least one person survives; samples
        that still fail fall through to the next index.
        """
        raw_data = self.get_raw_data(index)

        # Load original image
        ori_img = cv2.imread(raw_data['img_path'])
        # BEDLAM close-up frames are stored rotated; undo that here
        if raw_data['ds'] == 'bedlam' and 'closeup' in raw_data['img_path']:
            ori_img = cv2.rotate(ori_img, cv2.ROTATE_90_CLOCKWISE)
        img_size = torch.tensor(ori_img.shape[:2])
        raw_data.update({'img_size': img_size})

        if self.mode == 'train':
            cnt = 0
            while (True):
                aug_dict = self.get_aug_dict()
                img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                if meta_data['pnum'] > 0:
                    break
                cnt+=1
                if cnt >= 10:
                    # last resort: retry with geometry-neutral augmentation
                    aug_dict.update({'rot':0., 'scale':1., 'crop': False})
                    img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                    if meta_data['pnum'] == 0:
                        print('skipping: ' + meta_data['img_path'])
                        return self.__getitem__(index + 1)
                    # NOTE(review): on a successful neutral retry the loop still
                    # resamples a fresh augmentation — confirm this is intended

        elif self.mode == 'eval':
            assert not self.aug, f'No need to use augmentation when mode is {self.mode}!'
            aug_dict = self.get_aug_dict()
            img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)

        else:
            # infer mode: no targets, only the resized image
            assert not self.aug, f'No need to use augmentation when mode is {self.mode}!'
            meta_data = raw_data
            img = self.process_img(ori_img, meta_data)

        # delete unwanted keys
        if self.mode == 'train':
            for key in list(meta_data.keys()):
                if key not in self.img_keys and key not in self.human_keys:
                    del meta_data[key]

        # color jitter only under augmentation; ImageNet normalization always
        if self.aug:
            array2tensor = transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
        else:
            array2tensor = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])

        # pad both sides up to a multiple of the patch size (56 with SAT, else 14)
        patch_size = 14
        if self.use_sat:
            patch_size = 56
        # NOTE(review): this local shadows the imported `pad_img` helper
        pad_img = np.zeros((math.ceil(img.shape[0]/patch_size)*patch_size, math.ceil(img.shape[1]/patch_size)*patch_size, 3), dtype=img.dtype)
        pad_img[:img.shape[0], :img.shape[1]] = img
        assert max(pad_img.shape[:2]) == self.input_size
        # BGR (OpenCV) -> RGB before the torchvision transforms
        pad_img = Image.fromarray(pad_img[:,:,::-1].copy())
        norm_img = array2tensor(pad_img)

        if 'j2ds_mask' in meta_data:
            # NOTE(review): this overwrites the visibility mask computed in
            # check_visibility so every joint is supervised — confirm intended
            meta_data['j2ds_mask'][:,:,:] = True

        return norm_img, meta_data
|
| 460 |
+
|
| 461 |
+
    def visualize(self, results_save_dir = None, vis_num = 100):
        """Dump roughly `vis_num` evenly-spaced annotated samples to disk.

        For each selected index, the normalized tensor from ``__getitem__`` is
        converted back to a BGR image and overlaid with whichever annotations
        are present in the targets dict (SMPL meshes, boxes, scale maps).

        Args:
            results_save_dir: output folder; defaults to
                ``datasets_visualization/<ds_name>_<split>``.
            vis_num: approximate number of samples to visualize.
        """
        if results_save_dir is None:
            results_save_dir = os.path.join('datasets_visualization',f'{self.ds_name}_{self.split}')
        os.makedirs(results_save_dir, exist_ok=True)

        # visualize every vis_interval-th sample so ~vis_num images are written
        vis_interval = len(self)//vis_num

        for idx in tqdm(range(len(self))):
            if idx % vis_interval != 0:
                continue

            norm_img, targets = self.__getitem__(idx)

            # undo ImageNet normalization and convert the tensor back to BGR
            ori_img = tensor_to_BGR(unNormalize(norm_img).cpu())
            # basename without extension, e.g. "img0" from ".../img0.png"
            img_name = targets['img_path'].split('/')[-1].split('.')[-2]
            pnum = targets['pnum']  # NOTE(review): currently unused below

            if 'verts' in targets:
                # render GT meshes over the image, one distinct color per person
                colors = get_colors_rgb(len(targets['verts']))
                mesh_img = vis_meshes_img(img = ori_img.copy(),
                                verts = targets['verts'],
                                smpl_faces = self.human_model.faces,
                                cam_intrinsics = targets['cam_intrinsics'].cpu(),
                                colors=colors,
                                padding=False)
                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_mesh.jpg'), mesh_img)


            if 'boxes' in targets:
                gt_img = ori_img.copy()
                # boxes are stored as normalized cxcywh; convert to pixel xyxy
                boxes = box_cxcywh_to_xyxy(targets['boxes']) * self.input_size
                for i, bbox in enumerate(boxes):
                    bbox = bbox.int().tolist()
                    cv2.rectangle(gt_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                                  color=(0,0,255), thickness = 2 )

                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_boxes.jpg'), gt_img)

            if 'scale_map' in targets:
                gt_img = ori_img.copy()
                # scale_map is stored flattened with explicit (y, x) positions;
                # scatter it back onto a dense (h, w, 2) grid before drawing
                flatten_map = targets['scale_map']
                ys, xs = targets['scale_map_pos']['pos_y'], targets['scale_map_pos']['pos_x']
                h, w = targets['scale_map_hw']
                scale_map = torch.zeros((h,w,2))
                scale_map[ys,xs] = flatten_map
                img = vis_scale_img(gt_img, scale_map, patch_size=28)

                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_scales.jpg'), img)


            # if 'j2ds' in targets:
            #     gt_img = ori_img.copy()
            #     j2ds = targets['j2ds']
            #     j2ds_mask = targets['j2ds_mask']
            #     for kpts, valids in zip(j2ds, j2ds_mask):
            #         for kpt, valid in zip(kpts, valids):
            #             if not valid.all():
            #                 continue
            #             kpt_int = kpt.numpy().astype(int)
            #             cv2.circle(gt_img, kpt_int, 2, (0, 0, 255), -1)
            #     cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_joints.png'), np.hstack([ori_img, gt_img]))
datasets/bedlam.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.dataset import Dataset
|
| 4 |
+
import os
|
| 5 |
+
from configs.paths import dataset_root
|
| 6 |
+
import copy
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from .base import BASE
|
| 9 |
+
|
| 10 |
+
class BEDLAM(BASE):
    """BEDLAM synthetic-human dataset; SMPL annotations come from one npz file."""

    def __init__(self, split='train_6fps', **kwargs):
        super(BEDLAM, self).__init__(**kwargs)
        valid_splits = ('train_1fps', 'train_3fps', 'train_6fps', 'validation_6fps')
        assert split in valid_splits
        # BEDLAM annotations carry no kid shape offset.
        assert not self.kid_offset

        self.ds_name = 'bedlam'
        self.dataset_path = os.path.join(dataset_root, 'bedlam')
        annots_file = os.path.join(self.dataset_path, f'bedlam_smpl_{split}.npz')
        # npz stores a single pickled dict under the 'annots' key
        self.annots = np.load(annots_file, allow_pickle=True)['annots'][()]
        self.img_names = list(self.annots.keys())
        # image folders on disk are named by the coarse split only
        self.split = 'train' if 'train' in split else 'validation'

    def __len__(self):
        return len(self.img_names)

    def cnt_instances(self):
        """Print the total number of annotated person instances in this split."""
        total = 0
        for i in tqdm(range(len(self))):
            name = self.img_names[i]
            total += len(self.annots[name]['shape'])

        print(f'TOTAL: {total}')

    def get_raw_data(self, idx):
        """Assemble per-image raw annotations (world-frame SMPL params + camera)."""
        img_name = self.img_names[idx % len(self.img_names)]

        # deep copy so downstream augmentation cannot mutate the cached annots
        annots = copy.deepcopy(self.annots[img_name])
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        cam_intrinsics = torch.from_numpy(annots['cam_int']).unsqueeze(0)
        cam_rot = torch.from_numpy(np.stack(annots['cam_rot']))
        cam_trans = torch.from_numpy(np.stack(annots['cam_trans']))

        betas = torch.from_numpy(np.stack(annots['shape']))
        poses = torch.from_numpy(np.stack(annots['pose_world']))
        transl = torch.from_numpy(np.stack(annots['trans_world']))

        raw_data = {
            'img_path': img_path,
            'ds': 'bedlam',
            'pnum': len(betas),
            'betas': betas.float(),
            'poses': poses.float(),
            'transl': transl.float(),
            'cam_rot': cam_rot.float(),
            'cam_trans': cam_trans.float(),
            'cam_intrinsics': cam_intrinsics.float(),
            '3d_valid': True,
            'age_valid': False,
            'detect_all_people': True,
        }

        if self.mode == 'eval':
            # synthetic data: every instance treated as occlusion level 0
            raw_data['occ_level'] = torch.zeros(len(betas), dtype=int)

        return raw_data
datasets/common.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.dataset import Dataset
|
| 4 |
+
import os
|
| 5 |
+
from configs.paths import dataset_root
|
| 6 |
+
import copy
|
| 7 |
+
from .base import BASE
|
| 8 |
+
|
| 9 |
+
# Plain image-folder dataset used for inference only (no annotations).
class COMMON(BASE):
    def __init__(self, img_folder, **kwargs):
        super(COMMON, self).__init__(**kwargs)
        self.dataset_path = img_folder
        image_exts = ('.png', '.jpg', '.jpeg')
        self.img_names = sorted(
            name for name in os.listdir(self.dataset_path)
            if name.endswith(image_exts)
        )
        # this dataset only supplies images, so it is inference-only
        assert self.mode == 'infer'

    def __len__(self):
        return len(self.img_names)

    def get_raw_data(self, idx):
        """Return a minimal record: image path/name only, no GT annotations."""
        name = self.img_names[idx % len(self.img_names)]
        return {
            'img_path': os.path.join(self.dataset_path, name),
            'img_name': name,
            'ds': 'common',
        }
datasets/multiple_datasets.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from torch.utils.data.dataset import Dataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
from .agora import AGORA
|
| 5 |
+
from .bedlam import BEDLAM
|
| 6 |
+
|
| 7 |
+
datasets_dict = {'bedlam': BEDLAM, 'agora': AGORA}
|
| 8 |
+
|
| 9 |
+
class MultipleDatasets(Dataset):
    """Expose several sub-datasets behind one ``torch.utils.data.Dataset``.

    Two indexing schemes:
      * ``make_same_len=False`` (default): plain concatenation; the global
        index is mapped onto the sub-dataset containing it via cumulative
        lengths.
      * ``make_same_len=True``: every sub-dataset is virtually repeated to
        the length of the longest one, so each contributes equally per epoch;
        the uneven tail is filled by random sampling.
    """

    def __init__(self, datasets_used, datasets_split = None, make_same_len = False, **kwargs):
        """Instantiate every requested sub-dataset.

        Args:
            datasets_used: iterable of dataset names (keys of ``datasets_dict``).
            datasets_split: optional per-dataset split names, same order/length
                as ``datasets_used``.
            make_same_len: see class docstring.
            kwargs: forwarded verbatim to every sub-dataset constructor.
        """
        if datasets_split is None:
            self.dbs = [datasets_dict[ds](**kwargs) for ds in datasets_used]
        else:
            self.dbs = [datasets_dict[ds](split, **kwargs)
                        for ds, split in zip(datasets_used, datasets_split)]

        self.db_num = len(self.dbs)
        self.max_db_data_num = max(len(db) for db in self.dbs)
        self.db_len_cumsum = np.cumsum([len(db) for db in self.dbs])
        self.make_same_len = make_same_len
        # all sub-datasets share the same parametric human model; expose the first
        self.human_model = self.dbs[0].human_model

    def __len__(self):
        if self.make_same_len:
            # all dbs padded to the length of the longest one
            return self.max_db_data_num * self.db_num
        # plain concatenation: sum of the real lengths
        return sum(len(db) for db in self.dbs)

    def __getitem__(self, index):
        """Map a global index to (sub-dataset, local index) and fetch the item."""
        if self.make_same_len:
            db_idx = index // self.max_db_data_num
            data_idx = index % self.max_db_data_num
            if data_idx >= len(self.dbs[db_idx]) * (self.max_db_data_num // len(self.dbs[db_idx])):
                # last partial repetition: no full copy covers it, sample randomly
                data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
            else:
                # within a full repetition: wrap around with modulo
                data_idx = data_idx % len(self.dbs[db_idx])
        else:
            # Fix: an index past the end used to fall through the cumsum scan
            # below and raise UnboundLocalError on db_idx; raise the proper
            # IndexError instead (negative indices keep their old behavior).
            if index >= self.db_len_cumsum[-1]:
                raise IndexError(
                    f'index {index} out of range for dataset of length {self.db_len_cumsum[-1]}')
            db_idx = 0
            for i in range(self.db_num):
                if index < self.db_len_cumsum[i]:
                    db_idx = i
                    break
            data_idx = index if db_idx == 0 else index - self.db_len_cumsum[db_idx - 1]

        norm_img, meta_data = self.dbs[db_idx][data_idx]
        return norm_img, meta_data
demo/img0.png
ADDED
|
Git LFS Details
|
demo/img1.jpeg
ADDED
|
demo/img2.jpg
ADDED
|
Git LFS Details
|
docs/fix_chumpy.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You may need to modify the installed `chumpy` package to avoid errors when it is used with recent Python/NumPy versions.
|
| 2 |
+
|
| 3 |
+
* Comment line 11 in `${Your_Conda_Environment}/lib/python3.11/site-packages/chumpy/__init__.py`:
|
| 4 |
+
```
|
| 5 |
+
from .ch import *
|
| 6 |
+
from .logic import *
|
| 7 |
+
|
| 8 |
+
from .optimization import minimize
|
| 9 |
+
from . import extras
|
| 10 |
+
from . import testing
|
| 11 |
+
from .version import version as __version__
|
| 12 |
+
|
| 13 |
+
from .version import version as __version__
|
| 14 |
+
|
| 15 |
+
# from numpy import bool, int, float, complex, object, unicode, str, nan, inf
|
| 16 |
+
```
|
| 17 |
+
* Add *"inspect.getargspec = inspect.getfullargspec"* in `${Your_Conda_Environment}/lib/python3.11/site-packages/chumpy/ch.py` (line 25). Now it should look like:
|
| 18 |
+
```
|
| 19 |
+
#!/usr/bin/env python
|
| 20 |
+
# encoding: utf-8
|
| 21 |
+
"""
|
| 22 |
+
Author(s): Matthew Loper
|
| 23 |
+
|
| 24 |
+
See LICENCE.txt for licensing and contact information.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
__all__ = ['Ch', 'depends_on', 'MatVecMult', 'ChHandle', 'ChLambda']
|
| 29 |
+
|
| 30 |
+
import os, sys, time
|
| 31 |
+
import inspect
|
| 32 |
+
import scipy.sparse as sp
|
| 33 |
+
import numpy as np
|
| 34 |
+
import numbers
|
| 35 |
+
import weakref
|
| 36 |
+
import copy as external_copy
|
| 37 |
+
from functools import wraps
|
| 38 |
+
from scipy.sparse.linalg.interface import LinearOperator
|
| 39 |
+
from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
|
| 40 |
+
import collections
|
| 41 |
+
from copy import deepcopy
|
| 42 |
+
from functools import reduce
|
| 43 |
+
inspect.getargspec = inspect.getfullargspec
|
| 44 |
+
```
|
engines/__init__.py
ADDED
|
File without changes
|
engines/engine.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from accelerate import Accelerator
|
| 2 |
+
from tqdm.auto import tqdm
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from datasets.multiple_datasets import MultipleDatasets, datasets_dict
|
| 6 |
+
from datasets.common import COMMON
|
| 7 |
+
from transformers import get_scheduler
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import time
|
| 12 |
+
import datetime
|
| 13 |
+
from models import build_sat_model
|
| 14 |
+
from .funcs.eval_funcs import *
|
| 15 |
+
from .funcs.infer_funcs import inference
|
| 16 |
+
from utils import misc
|
| 17 |
+
from utils.misc import get_world_size
|
| 18 |
+
import torch.multiprocessing
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Engine():
    """Top-level train / eval / infer driver built on Hugging Face Accelerate.

    Owns the accelerator, model + criterion, dataloaders, optimizer and
    lr scheduler, and dispatches to the per-dataset evaluation functions in
    ``self.eval_func_maps`` or the generic ``inference`` function.
    """

    def __init__(self, args, mode='train'):
        """Set up output paths for `mode`, then build accelerator/model/data.

        Args:
            args: parsed experiment config (attribute access).
            mode: one of 'train', 'eval', 'infer'.
        """
        self.exp_name = args.exp_name
        self.mode = mode
        assert mode in ['train','eval','infer']
        self.conf_thresh = args.conf_thresh
        # maps "<dataset>_<split>" keys to their evaluation entry point
        self.eval_func_maps = {'agora_validation': evaluate_agora,
                               'bedlam_validation_6fps': evaluate_agora,
                               'agora_test': test_agora}
        self.inference_func = inference

        if self.mode == 'train':
            self.output_dir = os.path.join('./outputs')
            self.log_dir = os.path.join(self.output_dir,'logs')
            self.ckpt_dir = os.path.join(self.output_dir,'ckpts')
            self.distributed_eval = args.distributed_eval
            self.eval_vis_num = args.eval_vis_num
        elif self.mode == 'eval':
            self.output_dir = os.path.join('./results')
            self.distributed_eval = args.distributed_eval
            self.eval_vis_num = args.eval_vis_num
        elif self.mode == 'infer':
            output_dir = getattr(args, 'output_dir', None)
            if output_dir is not None:
                self.output_dir = output_dir
            else:
                # default inference output folder gets a unique timestamp suffix
                now = datetime.datetime.now()
                timestamp = now.strftime("%Y%m%d_%H%M%S")
                self.output_dir = os.path.join('./results',f'{self.exp_name}_infer_{timestamp}')
            self.distributed_infer = args.distributed_infer

        self.prepare_accelerator()
        self.prepare_models(args)
        self.prepare_datas(args)
        if self.mode == 'train':
            self.prepare_training(args)

        total_cnt = sum(p.numel() for p in self.model.parameters())
        trainable_cnt = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.accelerator.print(f'Initialization finished.\n{trainable_cnt} trainable parameters({total_cnt} total).')

    def prepare_accelerator(self):
        """Create the Accelerator; in train mode also set up tensorboard logging."""
        if self.mode == 'train':
            self.accelerator = Accelerator(
                log_with="tensorboard",
                project_dir=os.path.join(self.log_dir)
            )
            if self.accelerator.is_main_process:
                os.makedirs(self.log_dir, exist_ok=True)
                os.makedirs(os.path.join(self.ckpt_dir,self.exp_name),exist_ok=True)
                self.accelerator.init_trackers(self.exp_name)
        else:
            self.accelerator = Accelerator()
            if self.accelerator.is_main_process:
                os.makedirs(self.output_dir, exist_ok=True)

    def prepare_models(self, args):
        """Build the model (and criterion in train mode), load weights, wrap for device."""
        # load model and criterion
        self.accelerator.print('Preparing models...')
        self.unwrapped_model, self.criterion = build_sat_model(args, set_criterion = (self.mode == 'train'))
        if self.criterion is not None:
            self.weight_dict = self.criterion.weight_dict
        # load weights (strict=False: pretrain checkpoints may be partial)
        if args.pretrain:
            self.accelerator.print(f'Loading pretrained weights: {args.pretrain_path}')
            state_dict = torch.load(args.pretrain_path)
            self.unwrapped_model.load_state_dict(state_dict,strict=False)

        # to gpu
        self.model = self.accelerator.prepare(self.unwrapped_model)

    def prepare_datas(self, args):
        """Build the dataloaders required by the current mode.

        train: a MultipleDatasets training loader plus eval loaders;
        eval: eval loaders only; infer: a COMMON image-folder loader.
        """
        # load dataset and dataloader
        if self.mode == 'train':
            self.accelerator.print('Loading training datasets:\n',
                    [f'{d}_{s}' for d,s in zip(args.train_datasets_used, args.train_datasets_split)])
            self.train_batch_size = args.train_batch_size
            train_dataset = MultipleDatasets(args.train_datasets_used, args.train_datasets_split,
                                             make_same_len=False, input_size=args.input_size, aug=True,
                                             mode = 'train', sat_cfg=args.sat_cfg,
                                             aug_cfg=args.aug_cfg)
            self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.train_batch_size,
                                               shuffle=True,collate_fn=misc.collate_fn,
                                               num_workers=args.train_num_workers,pin_memory=True)
            self.train_dataloader = self.accelerator.prepare(self.train_dataloader)

        if self.mode != 'infer':
            self.accelerator.print('Loading evaluation datasets:',
                    [f'{d}_{s}' for d,s in zip(args.eval_datasets_used, args.eval_datasets_split)])
            self.eval_batch_size = args.eval_batch_size
            # one dataset + loader per "<dataset>_<split>" pair
            eval_ds = {f'{ds}_{split}': datasets_dict[ds](split = split,
                                                          mode = 'eval',
                                                          input_size = args.input_size,
                                                          aug = False,
                                                          sat_cfg=args.sat_cfg)\
                       for (ds, split) in zip(args.eval_datasets_used, args.eval_datasets_split)}
            self.eval_dataloaders = {k: DataLoader(dataset=v, batch_size=self.eval_batch_size,
                                                   shuffle=False,collate_fn=misc.collate_fn,
                                                   num_workers=args.eval_num_workers,pin_memory=True)\
                                     for (k,v) in eval_ds.items()}
            if self.distributed_eval:
                # shard eval loaders across processes only when distributed eval is on
                for (k,v) in self.eval_dataloaders.items():
                    self.eval_dataloaders.update({k: self.accelerator.prepare(v)})

        else:
            img_folder = args.input_dir
            self.accelerator.print(f'Loading inference images from {img_folder}')
            self.infer_batch_size = args.infer_batch_size
            infer_ds = COMMON(img_folder = img_folder, input_size=args.input_size,aug=False,
                              mode = 'infer', sat_cfg=args.sat_cfg)
            self.infer_dataloader = DataLoader(dataset=infer_ds, batch_size=self.infer_batch_size,
                                               shuffle=False,collate_fn=misc.collate_fn,
                                               num_workers=args.infer_num_workers,pin_memory=True)

            if self.distributed_infer:
                self.infer_dataloader = self.accelerator.prepare(self.infer_dataloader)

    def prepare_training(self, args):
        """Set up training state: epochs/steps, optimizer, lr scheduler, resume."""
        self.start_epoch = 0
        self.num_epochs = args.num_epochs
        self.global_step = 0
        if hasattr(args, 'sat_gt_epoch'):
            # warm-up phase where SAT uses ground truth instead of predictions
            self.sat_gt_epoch = args.sat_gt_epoch
            self.accelerator.print(f'Use GT for the first {self.sat_gt_epoch} epoch(s)...')
        else:
            self.sat_gt_epoch = -1
        self.save_and_eval_epoch = args.save_and_eval_epoch
        self.least_eval_epoch = args.least_eval_epoch

        self.detach_j3ds = args.detach_j3ds

        self.accelerator.print('Preparing optimizer and lr_scheduler...')
        # two parameter groups: encoder params get their own learning rate
        param_dicts = [
            {
                "params":
                    [p for n, p in self.unwrapped_model.named_parameters()
                     if not misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
                "lr": args.lr,
            },
            {
                "params":
                    [p for n, p in self.unwrapped_model.named_parameters()
                     if misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
                "lr": args.lr_encoder,
            }
        ]

        # optimizer
        if args.optimizer == 'adamw':
            self.optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                               weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        # lr_scheduler (stepped per iteration in train(), hence the total-step count)
        if args.lr_scheduler == 'cosine':
            self.lr_scheduler = get_scheduler(name="cosine", optimizer=self.optimizer,
                                              num_warmup_steps=args.num_warmup_steps,
                                              num_training_steps=get_world_size() * self.num_epochs * len(self.train_dataloader))
        elif args.lr_scheduler == 'multistep':
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, args.milestones, gamma=args.gamma)
        else:
            raise NotImplementedError

        self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler)

        # resume
        if args.resume: #load model, optimizer, lr_scheduler and random_state
            if hasattr(args, 'ckpt_epoch'):
                self.load_ckpt(args.ckpt_epoch,args.ckpt_step)
            else:
                # no explicit checkpoint requested: pick the highest epoch on disk
                self.accelerator.print('Auto resume from latest ckpt...')
                epoch, step = -1, -1
                pattern = re.compile(r'epoch_(\d+)_step_(\d+)')
                for folder_name in os.listdir(os.path.join(self.output_dir,'ckpts',self.exp_name)):
                    match = pattern.match(folder_name)
                    if match:
                        i, j = int(match.group(1)), int(match.group(2))
                        if i > epoch:
                            epoch, step = i, j
                if epoch >= 0:
                    self.load_ckpt(epoch, step)
                else:
                    self.accelerator.print('No existing ckpts! Train from scratch.')

    def load_ckpt(self, epoch, step):
        """Restore accelerator state (model/optimizer/scheduler/RNG) from a checkpoint dir."""
        self.accelerator.print(f'Loading checkpoint: epoch_{epoch}_step_{step}')
        ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{step}')
        # resume from the step AFTER the saved one
        self.start_epoch = epoch + 1
        self.global_step = step + 1
        self.accelerator.load_state(ckpts_save_path)

    def train(self):
        """Main training loop: forward, weighted loss, backward, step, log, eval."""
        # torch.autograd.set_detect_anomaly(True)
        self.accelerator.print('Start training!')
        for epoch in range(self.start_epoch, self.num_epochs):
            torch.cuda.empty_cache()
            progress_bar = tqdm(total=len(self.train_dataloader), disable=not self.accelerator.is_local_main_process)
            progress_bar.set_description(f"Epoch {epoch}")

            self.model.train()
            self.criterion.train()

            # feed GT to the SAT module during the warm-up epochs
            sat_use_gt = (epoch < self.sat_gt_epoch)

            for step, (samples,targets) in enumerate(self.train_dataloader):

                outputs = self.model(samples, targets, sat_use_gt = sat_use_gt, detach_j3ds = self.detach_j3ds)
                loss_dict = self.criterion(outputs, targets)

                # total loss = per-term losses scaled by the criterion's weights
                loss = sum(loss_dict[k] * self.weight_dict[k] for k in loss_dict.keys())


                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()

                # scheduler is stepped per iteration (matches cosine step count above)
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                # average losses across processes; drop per-layer ('.'-named) terms
                reduced_dict = self.accelerator.reduce(loss_dict,reduction='mean')
                simplified_logs = {k: v.item() for k, v in reduced_dict.items() if '.' not in k}

                # logs.update({"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})
                if self.accelerator.is_main_process:
                    tqdm.write(f'[{epoch}-{step+1}/{len(self.train_dataloader)}]: ' + str(simplified_logs))

                    if step % 10 == 0:
                        self.accelerator.log({('train/'+k):v for k,v in simplified_logs.items()},
                                             step=self.global_step)

                progress_bar.update(1)
                progress_bar.set_postfix(**{"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})

                self.global_step += 1
            self.accelerator.wait_for_everyone()

            # self.lr_scheduler.step()

            if epoch % self.save_and_eval_epoch == 0 or epoch == self.num_epochs-1:
                self.save_and_eval(epoch, save_ckpt=True)

        self.accelerator.end_training()

    def eval(self, results_save_path = None, epoch = -1):
        """Run every configured eval dataloader through its evaluation function.

        Args:
            results_save_path: where metric files/visualizations are written;
                defaults to ``<output_dir>/<exp_name>/evaluation``.
            epoch: used as the tracker step when logging metrics in train mode.
        """
        if results_save_path is None:
            results_save_path = os.path.join(self.output_dir,self.exp_name,'evaluation')
        # preparing
        self.model.eval()
        unwrapped_model = self.unwrapped_model # self.accelerator.unwrap_model(self.model)
        if self.accelerator.is_main_process:
            os.makedirs(results_save_path,exist_ok=True)
        # evaluate
        for i, (key, eval_dataloader) in enumerate(self.eval_dataloaders.items()):
            assert key in self.eval_func_maps
            img_cnt = len(eval_dataloader) * self.eval_batch_size
            if self.distributed_eval:
                img_cnt *= self.accelerator.num_processes
            self.accelerator.print(f'Evaluate on {key}: {img_cnt} images')
            self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
            # non-agora/bedlam datasets are evaluated at a single fixed threshold
            conf_thresh = self.conf_thresh if 'agora' in key or 'bedlam' in key else [0.2]
            for thresh in conf_thresh:
                # without distributed eval only the main process evaluates
                if self.accelerator.is_main_process or self.distributed_eval:
                    error_dict = self.eval_func_maps[key](model = unwrapped_model,
                                          eval_dataloader = eval_dataloader,
                                          conf_thresh = thresh,
                                          vis_step = img_cnt // self.eval_vis_num,
                                          results_save_path = os.path.join(results_save_path,key,f'thresh_{thresh}'),
                                          distributed = self.distributed_eval,
                                          accelerator = self.accelerator,
                                          vis=True)
                    if isinstance(error_dict,dict) and self.mode == 'train':
                        # flatten nested metric dicts before sending to the tracker
                        log_dict = flatten_dict(error_dict)
                        self.accelerator.log({(f'{key}_thresh_{thresh}/'+k):v for k,v in log_dict.items()}, step=epoch)

                    self.accelerator.print(f'thresh_{thresh}: ',error_dict)
                self.accelerator.wait_for_everyone()

    def save_and_eval(self, epoch, save_ckpt=False):
        """Optionally checkpoint the full training state, then evaluate.

        Evaluation is skipped while ``epoch < self.least_eval_epoch``.
        """
        torch.cuda.empty_cache()
        # save current state and model
        if self.accelerator.is_main_process and save_ckpt:
            ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
            os.makedirs(ckpts_save_path,exist_ok=True)
            self.accelerator.save_state(ckpts_save_path, safe_serialization=False)
        self.accelerator.wait_for_everyone()

        if epoch < self.least_eval_epoch:
            return
        results_save_path = os.path.join(self.output_dir,'results',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
        self.eval(results_save_path, epoch=epoch)

    def infer(self):
        """Run inference over the COMMON image-folder loader at each confidence threshold."""
        self.model.eval()
        # unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model = self.unwrapped_model

        results_save_path = self.output_dir
        if self.accelerator.is_main_process:
            os.makedirs(results_save_path,exist_ok=True)

        self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
        for thresh in self.conf_thresh:
            # without distributed inference only the main process runs
            if self.accelerator.is_main_process or self.distributed_infer:
                self.inference_func(model = unwrapped_model,
                                    infer_dataloader = self.infer_dataloader,
                                    conf_thresh = thresh,
                                    results_save_path = os.path.join(results_save_path,f'thresh_{thresh}'),
                                    distributed = self.distributed_infer,
                                    accelerator = self.accelerator)
            self.accelerator.wait_for_everyone()
|
| 338 |
+
def flatten_dict(d, parent_key='', sep='-'):
    """Flatten a nested dict into one level, joining key paths with `sep`.

    Example: {'a': 1, 'b': {'c': 2}} -> {'a': 1, 'b-c': 2}.
    """
    flat = {}
    for key, value in d.items():
        compound_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # recurse, carrying the accumulated key prefix down
            flat.update(flatten_dict(value, compound_key, sep=sep))
        else:
            flat[compound_key] = value
    return flat
engines/funcs/__init__.py
ADDED
|
File without changes
|
engines/funcs/eval_funcs.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from tqdm.auto import tqdm
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from utils.evaluation import cal_3d_position_error, match_2d_greedy, get_matching_dict, compute_prf1, vectorize_distance, calculate_iou
|
| 6 |
+
from utils.transforms import pelvis_align, root_align, unNormalize
|
| 7 |
+
from utils.visualization import tensor_to_BGR, pad_img
|
| 8 |
+
from utils.visualization import vis_meshes_img, vis_boxes, vis_sat, vis_scale_img, get_colors_rgb
|
| 9 |
+
from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
|
| 10 |
+
from utils.constants import human36_eval_joint, J24_TO_H36M, H36M_TO_MPII
|
| 11 |
+
import time
|
| 12 |
+
import datetime
|
| 13 |
+
import scipy.io as sio
|
| 14 |
+
import cv2
|
| 15 |
+
import zipfile
|
| 16 |
+
import pickle
|
| 17 |
+
|
| 18 |
+
# for agora evaluation
|
# for agora evaluation
def select_and_align(smpl_joints, smpl_verts, body_verts_ind):
    """Select the 24 SMPL body joints and the body-only vertices, then pelvis-align both.

    Args:
        smpl_joints: per-person joint array; only the first 24 rows (SMPL body joints) are kept.
        smpl_verts: full SMPL vertex array, indexed down to body vertices via `body_verts_ind`.
        body_verts_ind: index array selecting body (non-hand/face) vertices.

    Returns:
        (joints, verts): both translated so the pelvis (derived from `joints`) is at the origin.
        NOTE(review): exact pelvis definition lives in utils.transforms.pelvis_align — the
        two-arg form aligns `verts` using the pelvis computed from `joints`.
    """
    joints = smpl_joints[:24, :]
    verts = smpl_verts[body_verts_ind, :]
    assert len(verts.shape) == 2
    # Align verts first, while `joints` is still un-shifted, so both use the same pelvis.
    verts = pelvis_align(joints, verts)
    joints = pelvis_align(joints)
    return joints, verts
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Modified from agora_evaluation
|
| 29 |
+
def evaluate_agora(model, eval_dataloader, conf_thresh,
|
| 30 |
+
vis = True, vis_step = 40, results_save_path = None,
|
| 31 |
+
distributed = False, accelerator = None):
|
| 32 |
+
assert results_save_path is not None
|
| 33 |
+
assert accelerator is not None
|
| 34 |
+
num_processes = accelerator.num_processes
|
| 35 |
+
|
| 36 |
+
has_kid = ('train' in eval_dataloader.dataset.split and eval_dataloader.dataset.ds_name == 'agora')
|
| 37 |
+
|
| 38 |
+
os.makedirs(results_save_path,exist_ok=True)
|
| 39 |
+
if vis:
|
| 40 |
+
imgs_save_dir = os.path.join(results_save_path, 'imgs')
|
| 41 |
+
os.makedirs(imgs_save_dir, exist_ok = True)
|
| 42 |
+
|
| 43 |
+
step = 0
|
| 44 |
+
total_miss_count = 0
|
| 45 |
+
total_count = 0
|
| 46 |
+
total_fp = 0
|
| 47 |
+
mve, mpjpe = [0.], [0.]
|
| 48 |
+
|
| 49 |
+
if has_kid:
|
| 50 |
+
kid_total_miss_count = 0
|
| 51 |
+
kid_total_count = 0
|
| 52 |
+
kid_mve, kid_mpjpe = [0.], [0.]
|
| 53 |
+
|
| 54 |
+
cur_device = next(model.parameters()).device
|
| 55 |
+
smpl_layer = model.human_model
|
| 56 |
+
body_verts_ind = smpl_layer.body_vertex_idx
|
| 57 |
+
|
| 58 |
+
progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
|
| 59 |
+
progress_bar.set_description('evaluate')
|
| 60 |
+
for itr, (samples, targets) in enumerate(eval_dataloader):
|
| 61 |
+
samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
outputs = model(samples, targets)
|
| 64 |
+
bs = len(targets)
|
| 65 |
+
for idx in range(bs):
|
| 66 |
+
#gt
|
| 67 |
+
gt_j2ds = targets[idx]['j2ds'].cpu().numpy()[:,:24,:]
|
| 68 |
+
gt_j3ds = targets[idx]['j3ds'].cpu().numpy()[:,:24,:]
|
| 69 |
+
gt_verts = targets[idx]['verts'].cpu().numpy()
|
| 70 |
+
|
| 71 |
+
#pred
|
| 72 |
+
select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
|
| 73 |
+
pred_j2ds = outputs['pred_j2ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
|
| 74 |
+
pred_j3ds = outputs['pred_j3ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
|
| 75 |
+
pred_verts = outputs['pred_verts'][idx][select_queries_idx].detach().cpu().numpy()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
matched_verts_idx = []
|
| 79 |
+
assert len(gt_j2ds.shape) == 3 and len(pred_j2ds.shape) == 3
|
| 80 |
+
#matching
|
| 81 |
+
greedy_match = match_2d_greedy(pred_j2ds, gt_j2ds) # tuples are (idx_pred_kps, idx_gt_kps)
|
| 82 |
+
matchDict, falsePositive_count = get_matching_dict(greedy_match)
|
| 83 |
+
|
| 84 |
+
#align with matching result
|
| 85 |
+
gt_verts_list, pred_verts_list, gt_joints_list, pred_joints_list = [], [], [], []
|
| 86 |
+
gtIdxs = np.arange(len(gt_j3ds))
|
| 87 |
+
miss_flag = []
|
| 88 |
+
for gtIdx in gtIdxs:
|
| 89 |
+
gt_verts_list.append(gt_verts[gtIdx])
|
| 90 |
+
gt_joints_list.append(gt_j3ds[gtIdx])
|
| 91 |
+
if matchDict[str(gtIdx)] == 'miss' or matchDict[str(
|
| 92 |
+
gtIdx)] == 'invalid':
|
| 93 |
+
miss_flag.append(1)
|
| 94 |
+
pred_verts_list.append([])
|
| 95 |
+
pred_joints_list.append([])
|
| 96 |
+
else:
|
| 97 |
+
miss_flag.append(0)
|
| 98 |
+
pred_joints_list.append(pred_j3ds[matchDict[str(gtIdx)]])
|
| 99 |
+
pred_verts_list.append(pred_verts[matchDict[str(gtIdx)]])
|
| 100 |
+
matched_verts_idx.append(matchDict[str(gtIdx)])
|
| 101 |
+
|
| 102 |
+
if has_kid:
|
| 103 |
+
gt_kid_list = targets[idx]['kid']
|
| 104 |
+
|
| 105 |
+
#calculating 3d errors
|
| 106 |
+
for i, (gt3d, pred) in enumerate(zip(gt_joints_list, pred_joints_list)):
|
| 107 |
+
total_count += 1
|
| 108 |
+
if has_kid and gt_kid_list[i]:
|
| 109 |
+
kid_total_count += 1
|
| 110 |
+
|
| 111 |
+
# Get corresponding ground truth and predicted 3d joints and verts
|
| 112 |
+
if miss_flag[i] == 1:
|
| 113 |
+
total_miss_count += 1
|
| 114 |
+
if has_kid and gt_kid_list[i]:
|
| 115 |
+
kid_total_miss_count += 1
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
gt3d = gt3d.reshape(-1, 3)
|
| 119 |
+
pred3d = pred.reshape(-1, 3)
|
| 120 |
+
gt3d_verts = gt_verts_list[i].reshape(-1, 3)
|
| 121 |
+
pred3d_verts = pred_verts_list[i].reshape(-1, 3)
|
| 122 |
+
|
| 123 |
+
gt3d, gt3d_verts = select_and_align(gt3d, gt3d_verts, body_verts_ind)
|
| 124 |
+
pred3d, pred3d_verts = select_and_align(pred3d, pred3d_verts, body_verts_ind)
|
| 125 |
+
|
| 126 |
+
#joints
|
| 127 |
+
error_j, pa_error_j = cal_3d_position_error(pred3d, gt3d)
|
| 128 |
+
mpjpe.append(error_j)
|
| 129 |
+
if has_kid and gt_kid_list[i]:
|
| 130 |
+
kid_mpjpe.append(error_j)
|
| 131 |
+
#vertices
|
| 132 |
+
error_v,pa_error_v = cal_3d_position_error(pred3d_verts, gt3d_verts)
|
| 133 |
+
mve.append(error_v)
|
| 134 |
+
if has_kid and gt_kid_list[i]:
|
| 135 |
+
kid_mve.append(error_v)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
#counting
|
| 139 |
+
step += 1
|
| 140 |
+
total_fp += falsePositive_count
|
| 141 |
+
|
| 142 |
+
img_idx = step + accelerator.process_index*len(eval_dataloader)*bs
|
| 143 |
+
|
| 144 |
+
if vis and (img_idx%vis_step == 0):
|
| 145 |
+
img_name = targets[idx]['img_path'].split('/')[-1].split('.')[0]
|
| 146 |
+
ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
|
| 147 |
+
|
| 148 |
+
# render mesh
|
| 149 |
+
colors = [(1.0, 1.0, 0.9)] * len(gt_verts)
|
| 150 |
+
gt_mesh_img = vis_meshes_img(img = ori_img.copy(),
|
| 151 |
+
verts = gt_verts,
|
| 152 |
+
smpl_faces = smpl_layer.faces,
|
| 153 |
+
cam_intrinsics = targets[idx]['cam_intrinsics'].reshape(3,3).detach().cpu(),
|
| 154 |
+
colors = colors)
|
| 155 |
+
|
| 156 |
+
colors = [(1.0, 0.6, 0.6)] * len(pred_verts)
|
| 157 |
+
for i in matched_verts_idx:
|
| 158 |
+
colors[i] = (0.7, 1.0, 0.4)
|
| 159 |
+
|
| 160 |
+
# colors = get_colors_rgb(len(pred_verts))
|
| 161 |
+
pred_mesh_img = vis_meshes_img(img = ori_img.copy(),
|
| 162 |
+
verts = pred_verts,
|
| 163 |
+
smpl_faces = smpl_layer.faces,
|
| 164 |
+
cam_intrinsics = outputs['pred_intrinsics'][idx].reshape(3,3).detach().cpu(),
|
| 165 |
+
colors = colors,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if 'enc_outputs' not in outputs:
|
| 170 |
+
pred_scale_img = np.zeros_like(pred_mesh_img)
|
| 171 |
+
else:
|
| 172 |
+
enc_out = outputs['enc_outputs']
|
| 173 |
+
h, w = enc_out['hw'][idx]
|
| 174 |
+
flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()
|
| 175 |
+
|
| 176 |
+
ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
|
| 177 |
+
xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
|
| 178 |
+
scale_map = torch.zeros((h,w,2))
|
| 179 |
+
scale_map[ys,xs] = flatten_map
|
| 180 |
+
|
| 181 |
+
pred_scale_img = vis_scale_img(img = ori_img.copy(),
|
| 182 |
+
scale_map = scale_map,
|
| 183 |
+
conf_thresh = model.sat_cfg['conf_thresh'],
|
| 184 |
+
patch_size=28)
|
| 185 |
+
|
| 186 |
+
pred_boxes = outputs['pred_boxes'][idx][select_queries_idx].detach().cpu()
|
| 187 |
+
pred_boxes = box_cxcywh_to_xyxy(pred_boxes) * model.input_size
|
| 188 |
+
pred_box_img = vis_boxes(ori_img.copy(), pred_boxes, color = (255,0,255))
|
| 189 |
+
|
| 190 |
+
# sat
|
| 191 |
+
sat_img = vis_sat(ori_img.copy(),
|
| 192 |
+
input_size = model.input_size,
|
| 193 |
+
patch_size = 14,
|
| 194 |
+
sat_dict = outputs['sat'],
|
| 195 |
+
bid = idx)
|
| 196 |
+
|
| 197 |
+
ori_img = pad_img(ori_img, model.input_size)
|
| 198 |
+
|
| 199 |
+
full_img = np.vstack([np.hstack([ori_img, sat_img]),
|
| 200 |
+
np.hstack([pred_scale_img, pred_box_img]),
|
| 201 |
+
np.hstack([gt_mesh_img, pred_mesh_img])])
|
| 202 |
+
|
| 203 |
+
cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.png'), full_img)
|
| 204 |
+
|
| 205 |
+
progress_bar.update(1)
|
| 206 |
+
|
| 207 |
+
if distributed:
|
| 208 |
+
mve = accelerator.gather_for_metrics(mve)
|
| 209 |
+
mpjpe = accelerator.gather_for_metrics(mpjpe)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
total_miss_count = sum(accelerator.gather_for_metrics([total_miss_count]))
|
| 213 |
+
total_count = sum(accelerator.gather_for_metrics([total_count]))
|
| 214 |
+
total_fp = sum(accelerator.gather_for_metrics([total_fp]))
|
| 215 |
+
|
| 216 |
+
if has_kid:
|
| 217 |
+
kid_mve = accelerator.gather_for_metrics(kid_mve)
|
| 218 |
+
kid_mpjpe = accelerator.gather_for_metrics(kid_mpjpe)
|
| 219 |
+
kid_total_miss_count = sum(accelerator.gather_for_metrics([kid_total_miss_count]))
|
| 220 |
+
kid_total_count = sum(accelerator.gather_for_metrics([kid_total_count]))
|
| 221 |
+
|
| 222 |
+
if len(mpjpe) <= num_processes:
|
| 223 |
+
return "Failed to evaluate. Keep training!"
|
| 224 |
+
if has_kid and len(kid_mpjpe) <= num_processes:
|
| 225 |
+
return "Failed to evaluate. Keep training!"
|
| 226 |
+
|
| 227 |
+
precision, recall, f1 = compute_prf1(total_count,total_miss_count,total_fp)
|
| 228 |
+
error_dict = {}
|
| 229 |
+
error_dict['precision'] = precision
|
| 230 |
+
error_dict['recall'] = recall
|
| 231 |
+
error_dict['f1'] = f1
|
| 232 |
+
|
| 233 |
+
error_dict['MPJPE'] = round(sum(mpjpe)/(len(mpjpe)-num_processes), 1)
|
| 234 |
+
error_dict['NMJE'] = round(error_dict['MPJPE'] / (f1), 1)
|
| 235 |
+
error_dict['MVE'] = round(sum(mve)/(len(mve)-num_processes), 1)
|
| 236 |
+
error_dict['NMVE'] = round(error_dict['MVE'] / (f1), 1)
|
| 237 |
+
|
| 238 |
+
if has_kid:
|
| 239 |
+
kid_precision, kid_recall, kid_f1 = compute_prf1(kid_total_count,kid_total_miss_count,total_fp)
|
| 240 |
+
error_dict['kid_precision'] = kid_precision
|
| 241 |
+
error_dict['kid_recall'] = kid_recall
|
| 242 |
+
error_dict['kid_f1'] = kid_f1
|
| 243 |
+
|
| 244 |
+
error_dict['kid-MPJPE'] = round(sum(kid_mpjpe)/(len(kid_mpjpe)-num_processes), 1)
|
| 245 |
+
error_dict['kid-NMJE'] = round(error_dict['kid-MPJPE'] / (kid_f1), 1)
|
| 246 |
+
error_dict['kid-MVE'] = round(sum(kid_mve)/(len(kid_mve)-num_processes), 1)
|
| 247 |
+
error_dict['kid-NMVE'] = round(error_dict['kid-MVE'] / (kid_f1), 1)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
if accelerator.is_main_process:
|
| 251 |
+
with open(os.path.join(results_save_path,'results.txt'),'w') as f:
|
| 252 |
+
for k,v in error_dict.items():
|
| 253 |
+
f.write(f'{k}: {v}\n')
|
| 254 |
+
|
| 255 |
+
return error_dict
|
| 256 |
+
|
def test_agora(model, eval_dataloader, conf_thresh,
               vis = True, vis_step = 400, results_save_path = None,
               distributed = False, accelerator = None):
    """Run the model on the AGORA test split (no GT) and write a submission zip.

    Writes one pickle per detected person under `results_save_path/predictions`
    in the format expected by the AGORA evaluation server, optionally dumps
    visualization images, and zips the predictions folder with a timestamped name.

    Args:
        model: SAT-HMR model; must expose `human_model`, `input_size`, `sat_cfg`.
        eval_dataloader: yields (samples, targets); targets carry 'img_name'.
        conf_thresh: confidence threshold for keeping predicted queries.
        vis / vis_step: enable and stride qualitative image dumps.
        results_save_path: required output directory.
        distributed: unused here; kept for a uniform engine interface.
        accelerator: required HF `accelerate` Accelerator.

    Returns:
        str message with the predictions directory.
    """
    assert results_save_path is not None
    assert accelerator is not None

    os.makedirs(os.path.join(results_save_path,'predictions'),exist_ok=True)
    if vis:
        imgs_save_dir = os.path.join(results_save_path, 'imgs')
        os.makedirs(imgs_save_dir, exist_ok = True)
    step = 0
    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model

    progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('testing')
    for itr, (samples, targets) in enumerate(eval_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            # gt (only the image name is needed on the test split)
            img_name = targets[idx]['img_name'].split('.')[0]
            # pred: keep only queries above the confidence threshold.
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            # 2D joints are rescaled from model input resolution to the 3840px
            # originals expected by the AGORA server — TODO confirm all test
            # images are 3840 wide.
            pred_j2ds = np.array(outputs['pred_j2ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]*(3840/model.input_size)
            pred_j3ds = np.array(outputs['pred_j3ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]
            pred_verts = np.array(outputs['pred_verts'][idx][select_queries_idx].detach().to('cpu'))
            pred_poses = np.array(outputs['pred_poses'][idx][select_queries_idx].detach().to('cpu'))
            pred_betas = np.array(outputs['pred_betas'][idx][select_queries_idx].detach().to('cpu'))

            # visualization
            step+=1
            # Globally-unique image index across processes for vis sampling.
            img_idx = step + accelerator.process_index*len(eval_dataloader)*bs
            if vis and (img_idx%vis_step == 0):
                ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
                ori_img = pad_img(ori_img, model.input_size)

                sat_img = vis_sat(ori_img.copy(),
                                  input_size = model.input_size,
                                  patch_size = 14,
                                  sat_dict = outputs['sat'],
                                  bid = idx)

                colors = get_colors_rgb(len(pred_verts))
                mesh_img = vis_meshes_img(img = ori_img.copy(),
                                          verts = pred_verts,
                                          smpl_faces = smpl_layer.faces,
                                          colors = colors,
                                          cam_intrinsics = outputs['pred_intrinsics'][idx].detach().cpu())

                if 'enc_outputs' not in outputs:
                    pred_scale_img = np.zeros_like(ori_img)
                else:
                    # Scatter the flattened per-token scale map back onto its (h, w) grid.
                    enc_out = outputs['enc_outputs']
                    h, w = enc_out['hw'][idx]
                    flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                    ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                    xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                    scale_map = torch.zeros((h,w,2))
                    scale_map[ys,xs] = flatten_map
                    pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                                   scale_map = scale_map,
                                                   conf_thresh = model.sat_cfg['conf_thresh'],
                                                   patch_size=28)

                # 2x2 tile: original/mesh, scale/sat
                full_img = np.vstack([np.hstack([ori_img, mesh_img]),
                                      np.hstack([pred_scale_img, sat_img])])
                cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.jpg'), full_img)

            # submit: one pickle per detected person, AGORA submission schema
            for pnum in range(len(pred_j2ds)):
                smpl_dict = {}
                # smpl_dict['age'] = 'kid'
                smpl_dict['joints'] = pred_j2ds[pnum].reshape(24,2)
                smpl_dict['params'] = {'transl': np.zeros((1,3)),
                                       'betas': pred_betas[pnum].reshape(1,10),
                                       'global_orient': pred_poses[pnum][:3].reshape(1,1,3),
                                       'body_pose': pred_poses[pnum][3:].reshape(1,23,3)}
                # smpl_dict['verts'] = pred_verts[pnum].reshape(6890,3)
                # smpl_dict['allSmplJoints3d'] = pred_j3ds[pnum].reshape(24,3)
                with open(os.path.join(results_save_path,'predictions',f'{img_name}_personId_{pnum}.pkl'), 'wb') as f:
                    pickle.dump(smpl_dict, f)

        progress_bar.update(1)

    accelerator.print('Packing...')

    # Zip the predictions folder (timestamped) for upload to the AGORA server.
    folder_path = os.path.join(results_save_path,'predictions')
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(results_save_path,f'pred_{timestamp}.zip')
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                # Keep the 'predictions/...' prefix inside the archive.
                arcname = os.path.relpath(file_path, os.path.dirname(folder_path))
                zipf.write(file_path, arcname)

    return 'Results saved at: ' + os.path.join(results_save_path,'predictions')
engines/funcs/infer_funcs.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
from tqdm.auto import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from utils.transforms import unNormalize
|
| 7 |
+
from utils.visualization import tensor_to_BGR, pad_img
|
| 8 |
+
from utils.visualization import vis_meshes_img, vis_boxes, vis_sat, vis_scale_img, get_colors_rgb
|
| 9 |
+
from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
|
| 10 |
+
import time
|
| 11 |
+
import cv2
|
| 12 |
+
import trimesh
|
| 13 |
+
|
def inference(model, infer_dataloader, conf_thresh, results_save_path = None,
              distributed = False, accelerator = None):
    """Run the model on arbitrary images and save composite visualizations.

    For every image, saves a 2x2 tile (boxes / mesh / scale map / sat map)
    cropped back to the original image size under `results_save_path`.

    Args:
        model: SAT-HMR model; must expose `human_model`, `input_size`, `sat_cfg`.
        infer_dataloader: yields (samples, targets); targets carry 'img_size', 'img_path'.
        conf_thresh: confidence threshold for keeping predicted queries.
        results_save_path: required output directory.
        distributed: unused here; kept for a uniform engine interface.
        accelerator: required HF `accelerate` Accelerator.
    """
    assert results_save_path is not None
    assert accelerator is not None

    accelerator.print(f'Results will be saved at: {results_save_path}')
    os.makedirs(results_save_path,exist_ok=True)
    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model

    progress_bar = tqdm(total=len(infer_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('inference')

    for itr, (samples, targets) in enumerate(infer_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            # (h, w) of the original image before padding to model.input_size
            img_size = targets[idx]['img_size'].detach().cpu().int().numpy()
            img_name = targets[idx]['img_path'].split('/')[-1].split('.')[0]

            # pred: keep only queries above the confidence threshold
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_verts = outputs['pred_verts'][idx][select_queries_idx].detach().cpu().numpy()

            # Whiten the padded region so the crop borders look clean.
            ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
            ori_img[img_size[0]:,:,:] = 255
            ori_img[:,img_size[1]:,:] = 255
            ori_img[img_size[0]:,img_size[1]:,:] = 255
            ori_img = pad_img(ori_img, model.input_size, pad_color_offset=255)

            # Each panel is cropped back to the original (h, w).
            sat_img = vis_sat(ori_img.copy(),
                              input_size = model.input_size,
                              patch_size = 14,
                              sat_dict = outputs['sat'],
                              bid = idx)[:img_size[0],:img_size[1]]

            colors = get_colors_rgb(len(pred_verts))
            pred_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                           verts = pred_verts,
                                           smpl_faces = smpl_layer.faces,
                                           cam_intrinsics = outputs['pred_intrinsics'][idx].reshape(3,3).detach().cpu(),
                                           colors=colors)[:img_size[0],:img_size[1]]

            if 'enc_outputs' not in outputs:
                pred_scale_img = np.zeros_like(ori_img)[:img_size[0],:img_size[1]]
            else:
                # Scatter the flattened per-token scale map back onto its (h, w) grid.
                enc_out = outputs['enc_outputs']
                h, w = enc_out['hw'][idx]
                flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                scale_map = torch.zeros((h,w,2))
                scale_map[ys,xs] = flatten_map

                pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                               scale_map = scale_map,
                                               conf_thresh = model.sat_cfg['conf_thresh'],
                                               patch_size=28)[:img_size[0],:img_size[1]]

            # Boxes are predicted in normalized cxcywh; scale to input pixels.
            pred_boxes = outputs['pred_boxes'][idx][select_queries_idx].detach().cpu()
            pred_boxes = box_cxcywh_to_xyxy(pred_boxes) * model.input_size
            pred_box_img = vis_boxes(ori_img.copy(), pred_boxes, color = (255,0,255))[:img_size[0],:img_size[1]]

            cv2.imwrite(os.path.join(results_save_path, f'{img_name}.png'), np.vstack([np.hstack([pred_box_img, pred_mesh_img]),
                                                                                       np.hstack([pred_scale_img, sat_img])]))

        progress_bar.update(1)
figures/pipeline.png
ADDED
|
Git LFS Details
|
figures/qualitative_results.png
ADDED
|
Git LFS Details
|
figures/results.png
ADDED
|
Git LFS Details
|
figures/results_3d.gif
ADDED
|
Git LFS Details
|
main.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
import numpy as np
|
| 5 |
+
from engines.engine import Engine
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_args_parser():
|
| 9 |
+
parser = argparse.ArgumentParser('SAT-HMR', add_help=False)
|
| 10 |
+
parser.add_argument('--cfg', default=None, type=str)
|
| 11 |
+
parser.add_argument('--mode',default='train',type=str)
|
| 12 |
+
|
| 13 |
+
return parser
|
| 14 |
+
|
| 15 |
+
def update_args(args, cfg_path):
|
| 16 |
+
with open(cfg_path) as f:
|
| 17 |
+
config = yaml.safe_load(f)
|
| 18 |
+
args_dict = vars(args)
|
| 19 |
+
args_dict.update(config)
|
| 20 |
+
args = argparse.Namespace(**args_dict)
|
| 21 |
+
return args
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
parser = argparse.ArgumentParser('SAT-HMR training and evaluation script', parents=[get_args_parser()])
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
assert args.cfg is not None
|
| 28 |
+
args = update_args(args, os.path.join('configs', 'run', f'{args.cfg}.yaml'))
|
| 29 |
+
args.exp_name = args.cfg
|
| 30 |
+
args = update_args(args, os.path.join('configs', 'models', f'{args.model}.yaml'))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if args.mode.lower() == 'train':
|
| 34 |
+
raise NotImplementedError
|
| 35 |
+
from accelerate.utils import set_seed
|
| 36 |
+
seed = args.seed
|
| 37 |
+
set_seed(args.seed)
|
| 38 |
+
engine = Engine(args, mode='train')
|
| 39 |
+
engine.train()
|
| 40 |
+
|
| 41 |
+
elif args.mode.lower() == 'eval':
|
| 42 |
+
raise NotImplementedError
|
| 43 |
+
engine = Engine(args, mode='eval')
|
| 44 |
+
engine.eval()
|
| 45 |
+
|
| 46 |
+
elif args.mode.lower() == 'infer':
|
| 47 |
+
engine = Engine(args, mode='infer')
|
| 48 |
+
engine.infer()
|
| 49 |
+
|
| 50 |
+
else:
|
| 51 |
+
print('Wrong mode!')
|
| 52 |
+
exit(1)
|
models/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 3 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
|
| 7 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 9 |
+
# ------------------------------------------------------------------------
|
| 10 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 11 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 12 |
+
# ------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
from .sat_model import build_sat_model
|
| 15 |
+
|
| 16 |
+
|
models/criterion.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from utils import box_ops
|
| 7 |
+
from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
| 8 |
+
accuracy, get_world_size, interpolate,
|
| 9 |
+
is_dist_avail_and_initialized, inverse_sigmoid)
|
| 10 |
+
|
| 11 |
+
def focal_loss(inputs, targets, valid_mask = None, alpha: float = 0.25, gamma: float = 2):
|
| 12 |
+
"""
|
| 13 |
+
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
| 14 |
+
Args:
|
| 15 |
+
inputs: A float tensor of arbitrary shape.
|
| 16 |
+
The predictions for each example.
|
| 17 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
| 18 |
+
classification label for each element in inputs
|
| 19 |
+
(0 for the negative class and 1 for the positive class).
|
| 20 |
+
alpha: (optional) Weighting factor in range (0,1) to balance
|
| 21 |
+
positive vs negative examples. Default = -1 (no weighting).
|
| 22 |
+
gamma: Exponent of the modulating factor (1 - p_t) to
|
| 23 |
+
balance easy vs hard examples.
|
| 24 |
+
Returns:
|
| 25 |
+
Loss tensor
|
| 26 |
+
"""
|
| 27 |
+
# prob = inputs.sigmoid()
|
| 28 |
+
# ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
| 29 |
+
prob = inputs
|
| 30 |
+
ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
|
| 31 |
+
p_t = prob * targets + (1 - prob) * (1 - targets)
|
| 32 |
+
loss = ce_loss * ((1 - p_t) ** gamma)
|
| 33 |
+
|
| 34 |
+
if alpha >= 0:
|
| 35 |
+
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
| 36 |
+
loss = alpha_t * loss
|
| 37 |
+
|
| 38 |
+
# if valid_mask is not None:
|
| 39 |
+
# loss = loss * valid_mask
|
| 40 |
+
|
| 41 |
+
return loss.mean()
|
| 42 |
+
|
| 43 |
+
class SetCriterion(nn.Module):
|
| 44 |
+
""" This class computes the loss for DETR.
|
| 45 |
+
The process happens in two steps:
|
| 46 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
| 47 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
| 48 |
+
"""
|
| 49 |
+
def __init__(self, matcher, weight_dict, losses = ['confs','boxes', 'poses','betas', 'j3ds','j2ds', 'depths', 'kid_offsets'],
|
| 50 |
+
focal_alpha=0.25, focal_gamma = 2.0, j2ds_norm_scale = 518):
|
| 51 |
+
""" Create the criterion.
|
| 52 |
+
Parameters:
|
| 53 |
+
num_classes: number of object categories, omitting the special no-object category
|
| 54 |
+
matcher: module able to compute a matching between targets and proposals
|
| 55 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
| 56 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
| 57 |
+
focal_alpha: alpha in Focal Loss
|
| 58 |
+
"""
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.matcher = matcher
|
| 61 |
+
self.losses = losses
|
| 62 |
+
if 'boxes' in losses and 'giou' not in weight_dict:
|
| 63 |
+
weight_dict.update({'giou': weight_dict['boxes']})
|
| 64 |
+
self.weight_dict = weight_dict
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
self.betas_weight = torch.tensor([2.56, 1.28, 0.64, 0.64, 0.32, 0.32, 0.32, 0.32, 0.32, 0.32]).unsqueeze(0).float()
|
| 68 |
+
self.focal_alpha = focal_alpha
|
| 69 |
+
self.focal_gamma = focal_gamma
|
| 70 |
+
self.j2ds_norm_scale = j2ds_norm_scale
|
| 71 |
+
self.device = None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def loss_boxes(self, loss, outputs, targets, indices, num_instances, **kwargs):
|
| 75 |
+
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
|
| 76 |
+
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
|
| 77 |
+
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
|
| 78 |
+
"""
|
| 79 |
+
assert 'pred_boxes' in outputs
|
| 80 |
+
assert loss == 'boxes'
|
| 81 |
+
idx = self._get_src_permutation_idx(indices)
|
| 82 |
+
valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]
|
| 83 |
+
|
| 84 |
+
if len(valid_idx) == 0:
|
| 85 |
+
return {loss: torch.tensor(0.).to(self.device)}
|
| 86 |
+
|
| 87 |
+
src = outputs['pred_'+loss][idx][valid_idx]
|
| 88 |
+
target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
|
| 89 |
+
assert src.shape == target.shape
|
| 90 |
+
|
| 91 |
+
src_boxes = src
|
| 92 |
+
target_boxes = target
|
| 93 |
+
|
| 94 |
+
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
|
| 95 |
+
|
| 96 |
+
losses = {}
|
| 97 |
+
losses['boxes'] = loss_bbox.sum() / num_instances
|
| 98 |
+
|
| 99 |
+
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
|
| 100 |
+
box_ops.box_cxcywh_to_xyxy(src_boxes),
|
| 101 |
+
box_ops.box_cxcywh_to_xyxy(target_boxes)))
|
| 102 |
+
losses['giou'] = loss_giou.sum() / num_instances
|
| 103 |
+
|
| 104 |
+
# # calculate the x,y and h,w loss
|
| 105 |
+
# with torch.no_grad():
|
| 106 |
+
# losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
|
| 107 |
+
# losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
return losses
|
| 111 |
+
|
| 112 |
+
def loss_boxes_enc(self, loss, outputs, targets, indices, num_instances, **kwargs):
|
| 113 |
+
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
|
| 114 |
+
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
|
| 115 |
+
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
|
| 116 |
+
"""
|
| 117 |
+
assert 'pred_boxes' in outputs
|
| 118 |
+
assert loss == 'boxes_enc'
|
| 119 |
+
loss = 'boxes'
|
| 120 |
+
|
| 121 |
+
valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]
|
| 122 |
+
|
| 123 |
+
if len(valid_idx) == 0:
|
| 124 |
+
return {loss: torch.tensor(0.).to(self.device)}
|
| 125 |
+
|
| 126 |
+
lens = outputs['lens']
|
| 127 |
+
pred_boxes = outputs['pred_boxes']
|
| 128 |
+
src = torch.cat([s[i] for s, (i, _) in zip(pred_boxes.split(lens), indices)], dim=0)[valid_idx]
|
| 129 |
+
target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
|
| 130 |
+
assert src.shape == target.shape
|
| 131 |
+
|
| 132 |
+
src_boxes = src
|
| 133 |
+
target_boxes = target
|
| 134 |
+
|
| 135 |
+
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
|
| 136 |
+
|
| 137 |
+
losses = {}
|
| 138 |
+
losses['boxes'] = loss_bbox.sum() / num_instances
|
| 139 |
+
|
| 140 |
+
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
|
| 141 |
+
box_ops.box_cxcywh_to_xyxy(src_boxes),
|
| 142 |
+
box_ops.box_cxcywh_to_xyxy(target_boxes)))
|
| 143 |
+
losses['giou'] = loss_giou.sum() / num_instances
|
| 144 |
+
|
| 145 |
+
# # calculate the x,y and h,w loss
|
| 146 |
+
# with torch.no_grad():
|
| 147 |
+
# losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
|
| 148 |
+
# losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
return losses
|
| 152 |
+
|
| 153 |
+
# For computing ['boxes', 'poses', 'betas', 'j3ds', 'j2ds'] losses
|
| 154 |
+
    # Shared L1 loss used for 'boxes'(not here), 'poses', 'betas', 'j3ds', 'j2ds', 'kid_offsets'.
    def loss_L1(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Generic per-element L1 loss between matched predictions and targets.

        `loss` names both the prediction key ('pred_'+loss in outputs) and the target key.
        Instances whose target dict lacks the key (e.g. 2D-only datasets without 3D
        annotations) are dropped via valid_idx. Returns {loss: scalar} normalized by
        num_instances.
        """
        idx = self._get_src_permutation_idx(indices)
        # Boolean per matched instance: does its image carry this annotation at all?
        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            # No supervised instance for this term in the batch.
            return {loss: torch.tensor(0.).to(self.device)}

        src = outputs['pred_'+loss][idx][valid_idx]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
        assert src.shape == target.shape

        losses = {}
        loss_mask = None

        if loss == 'j3ds':
            # Root aligned: subtract joint 0 so only the relative pose is penalized.
            src = src - src[...,[0],:].clone()
            target = target - target[...,[0],:].clone()
            # Use 54 smpl joints
            src = src[:,:54,:]
            target = target[:,:54,:]
        elif loss == 'j2ds':
            # Normalize pixel coordinates to roughly [0, 1] by j2ds_norm_scale.
            src = src / self.j2ds_norm_scale
            target = target / self.j2ds_norm_scale
            # Need to exclude invalid kpts in 2d datasets
            loss_mask = torch.cat([t['j2ds_mask'][i] for t, (_, i) in zip(targets, indices) if 'j2ds' in t], dim=0)
            # Use 54 smpl joints
            src = src[:,:54,:]
            target = target[:,:54,:]
            loss_mask = loss_mask[:,:54,:]

        valid_loss = torch.abs(src-target)

        if loss_mask is not None:
            # Zero out invisible/unsupervised keypoints.
            valid_loss = valid_loss * loss_mask
        if loss == 'betas':
            # Weight leading shape coefficients more heavily (see __init__).
            valid_loss = valid_loss*self.betas_weight.to(src.device)

        # Mean over per-instance elements, sum over instances, normalize by batch count.
        losses[loss] = valid_loss.flatten(1).mean(-1).sum()/num_instances

        return losses
|
| 203 |
+
|
| 204 |
+
    def loss_scale_map(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Loss on the encoder's dense scale map.

        Channel 0 of each map holds detection confidence labels, channel 1 holds a
        per-location scale value. Emits 'map_confs' (focal loss on channel 0) and
        'map_scales' (L1 on channel 1, only where a positive label exists).
        NOTE(review): matcher `indices` are unused here — the map is supervised densely.
        """
        assert loss == 'scale_map'

        pred_map = outputs['enc_outputs']['scale_map']
        tgt_map = torch.cat([t['scale_map'] for t in targets], dim=0)
        assert pred_map.shape == tgt_map.shape

        labels = tgt_map[:,0]
        pred_scales = pred_map[:,1]
        tgt_scales = tgt_map[:, 1]

        # Positives are always valid; negatives are only valid when the image is
        # annotated with all people present ('detect_all_people').
        detection_valid_mask = labels.bool()
        cur = 0
        lens = [len(t['scale_map']) for t in targets]
        for i, tgt in enumerate(targets):
            if tgt['detect_all_people']:
                detection_valid_mask[cur:cur+lens[i]] = True
            cur += lens[i]

        losses = {}
        # NOTE(review): the visible focal_loss implementation has its valid_mask use
        # commented out, so the mask may currently be a no-op — confirm upstream.
        losses['map_confs'] = focal_loss(pred_map[:,0], labels, valid_mask=detection_valid_mask)/1.
        losses['map_scales'] = torch.abs((pred_scales - tgt_scales)[torch.where(labels)[0]]).sum()/num_instances

        return losses
|
| 230 |
+
|
| 231 |
+
    def loss_confs(self, loss, outputs, targets, indices, num_instances, is_dn=False, **kwargs):
        """Focal classification loss on per-query confidence.

        Matched queries get label 1, all others 0. For non-dn training, negatives only
        count for images flagged 'detect_all_people' (otherwise unlabeled people would
        be wrongly penalized as background).
        """
        assert loss == 'confs'
        idx = self._get_src_permutation_idx(indices)
        pred_confs = outputs['pred_'+loss]

        # Labels/masks are supervision targets — keep them out of the graph.
        with torch.no_grad():
            labels = torch.zeros_like(pred_confs)
            labels[idx] = 1
            detection_valid_mask = torch.zeros_like(pred_confs,dtype=bool)
            detection_valid_mask[idx] = True
            valid_batch_idx = torch.where(torch.tensor([t['detect_all_people'] for t in targets]))[0]
            detection_valid_mask[valid_batch_idx] = True

        losses = {}
        if is_dn:
            # Denoising queries are all synthesized from GT, so every query is valid.
            losses[loss] = focal_loss(pred_confs, labels) / num_instances
        else:
            # NOTE(review): the visible focal_loss has the valid_mask application
            # commented out — confirm whether the mask is actually honored.
            losses[loss] = focal_loss(pred_confs, labels, valid_mask = detection_valid_mask) / num_instances

        return losses
|
| 252 |
+
|
| 253 |
+
    def loss_confs_enc(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Encoder-side confidence focal loss over flat (variable-length) predictions.

        Predictions for all images are concatenated; outputs['lens'] gives per-image
        lengths, so matcher indices are offset by a running cursor to index the flat
        tensor. Returns {'confs': scalar}.
        """
        assert loss == 'confs_enc'
        # Reported under the same key as the decoder confidence loss.
        loss = 'confs'

        lens = outputs['lens']
        pred_confs = outputs['pred_confs']
        detection_valid_mask = torch.zeros_like(pred_confs,dtype=bool)
        labels = torch.zeros_like(pred_confs)

        cur = 0
        idx = []
        for i, (src, tgt) in enumerate(indices):
            # Shift this image's matched indices into the flat coordinate space.
            idx += (src + cur).tolist()
            if targets[i]['detect_all_people']:
                detection_valid_mask[cur:cur+lens[i]] = True
            cur += lens[i]
        detection_valid_mask[idx] = True
        labels[idx] = 1

        # Add a leading batch dim so focal_loss sees a (1, N) batch.
        pred_confs = pred_confs.unsqueeze(0)
        labels = labels.unsqueeze(0)
        detection_valid_mask = detection_valid_mask.unsqueeze(0)

        losses = {}
        # NOTE(review): the valid_mask variant is deliberately disabled here; the mask
        # above is computed but unused.
        losses[loss] = focal_loss(pred_confs, labels)
        return losses
|
| 280 |
+
|
| 281 |
+
    def loss_L2(self, loss, outputs, targets, indices, num_instances, **kwargs):
        # Unused stub: not registered in get_loss()'s dispatch table; kept only for
        # interface symmetry with loss_L1.
        pass
|
| 283 |
+
|
| 284 |
+
    def loss_absolute_depths(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Loss on absolute person depth, compared in inverse-depth space.

        Predictions carry [d, d/f] per instance (channel 1 is focal-normalized depth);
        multiplying by the GT focal length recovers an absolute depth, which is then
        compared to the GT depth via |1/d_pred - 1/d_gt| (emphasizes near-range error).
        """
        assert loss == 'depths'
        losses = {}
        idx = self._get_src_permutation_idx(indices)
        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            return {loss: torch.tensor(0.).to(self.device)}

        src = outputs['pred_'+loss][idx][valid_idx][...,[1]] # [d d/f]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)[...,[0]]
        target_focals = torch.cat([t['focals'][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)

        # Denormalize: absolute depth = focal * (d/f).
        src = target_focals * src

        assert src.shape == target.shape

        # Inverse-depth L1; epsilon guards against division by zero.
        valid_loss = torch.abs(1./(src + 1e-8) - 1./(target + 1e-8))
        losses[loss] = valid_loss.flatten(1).mean(-1).sum()/num_instances
        return losses
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _get_src_permutation_idx(self, indices):
|
| 309 |
+
# permute predictions following indices
|
| 310 |
+
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
| 311 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
| 312 |
+
return batch_idx, src_idx
|
| 313 |
+
|
| 314 |
+
def _get_tgt_permutation_idx(self, indices):
|
| 315 |
+
# permute targets following indices
|
| 316 |
+
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
|
| 317 |
+
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
| 318 |
+
return batch_idx, tgt_idx
|
| 319 |
+
|
| 320 |
+
    def get_loss(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Dispatch a loss name to its implementation. Raises KeyError for unknown names
        (e.g. 'kid_offsets' is handled by loss_L1 only if it is ever added here —
        NOTE(review): 'kid_offsets' appears in the default losses list but has no entry
        in this table; confirm intended)."""
        loss_map = {
            'confs': self.loss_confs,
            'boxes': self.loss_boxes,
            'confs_enc': self.loss_confs_enc,
            'boxes_enc': self.loss_boxes_enc,
            'poses': self.loss_L1,
            'betas': self.loss_L1,
            'j3ds': self.loss_L1,
            'j2ds': self.loss_L1,
            'depths': self.loss_absolute_depths,
            'scale_map': self.loss_scale_map,
        }
        return loss_map[loss](loss, outputs, targets, indices, num_instances, **kwargs)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
    def get_valid_instances(self, targets):
        """Per-loss normalizers: average number of supervised instances across all
        distributed workers (for batch-size-invariant loss scaling).

        For 'scale_map' the count is the number of positive map locations; for every
        other loss it is the summed person count of targets carrying that annotation.
        'confs' is forced to 1 at the end, i.e. the confidence focal loss is not
        instance-normalized here (it divides by num_instances at the call site).
        """
        num_valid_instances = {}
        for loss in self.losses:
            num_instances = 0
            if loss != 'scale_map':
                for t in targets:
                    num_instances += t['pnum'] if loss in t else 0
                num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=self.device)
            else:
                for t in targets:
                    num_instances += t['scale_map'][...,0].sum().item()
                num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=self.device)
            # Average the count over workers; clamp so we never divide by zero.
            if is_dist_avail_and_initialized():
                torch.distributed.all_reduce(num_instances)
            num_instances = torch.clamp(num_instances / get_world_size(), min=1).item()
            num_valid_instances[loss] = num_instances
        num_valid_instances['confs'] = 1.
        return num_valid_instances
|
| 357 |
+
|
| 358 |
+
def prep_for_dn(self, dn_meta):
|
| 359 |
+
output_known = dn_meta['output_known']
|
| 360 |
+
num_dn_groups, pad_size = dn_meta['num_dn_group'], dn_meta['pad_size']
|
| 361 |
+
assert pad_size % num_dn_groups == 0
|
| 362 |
+
single_pad = pad_size//num_dn_groups
|
| 363 |
+
|
| 364 |
+
return output_known, single_pad, num_dn_groups
|
| 365 |
+
|
| 366 |
+
def forward(self, outputs, targets):
|
| 367 |
+
""" This performs the loss computation.
|
| 368 |
+
Parameters:
|
| 369 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
| 370 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
| 371 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
| 372 |
+
"""
|
| 373 |
+
# remove invalid information in targets
|
| 374 |
+
for t in targets:
|
| 375 |
+
if not t['3d_valid']:
|
| 376 |
+
for key in ['betas', 'kid_offsets', 'poses', 'j3ds', 'depths', 'focals']:
|
| 377 |
+
if key in t:
|
| 378 |
+
del t[key]
|
| 379 |
+
|
| 380 |
+
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs' and k != 'sat'}
|
| 381 |
+
# Retrieve the matching between the outputs of the last layer and the targets
|
| 382 |
+
indices = self.matcher(outputs_without_aux, targets)
|
| 383 |
+
self.device = outputs['pred_poses'].device
|
| 384 |
+
num_valid_instances = self.get_valid_instances(targets)
|
| 385 |
+
|
| 386 |
+
# Compute all the requested losses
|
| 387 |
+
losses = {}
|
| 388 |
+
|
| 389 |
+
# prepare for dn loss
|
| 390 |
+
if 'dn_meta' in outputs:
|
| 391 |
+
dn_meta = outputs['dn_meta']
|
| 392 |
+
output_known, single_pad, scalar = self.prep_for_dn(dn_meta)
|
| 393 |
+
|
| 394 |
+
dn_pos_idx = []
|
| 395 |
+
dn_neg_idx = []
|
| 396 |
+
for i in range(len(targets)):
|
| 397 |
+
assert len(targets[i]['boxes']) > 0
|
| 398 |
+
# t = torch.range(0, len(targets[i]['labels']) - 1).long().to(self.device)
|
| 399 |
+
t = torch.arange(0, len(targets[i]['labels'])).long().to(self.device)
|
| 400 |
+
t = t.unsqueeze(0).repeat(scalar, 1)
|
| 401 |
+
tgt_idx = t.flatten()
|
| 402 |
+
output_idx = (torch.tensor(range(scalar)) * single_pad).long().to(self.device).unsqueeze(1) + t
|
| 403 |
+
output_idx = output_idx.flatten()
|
| 404 |
+
|
| 405 |
+
dn_pos_idx.append((output_idx, tgt_idx))
|
| 406 |
+
dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))
|
| 407 |
+
|
| 408 |
+
l_dict = {}
|
| 409 |
+
for loss in self.losses:
|
| 410 |
+
if loss == 'scale_map':
|
| 411 |
+
continue
|
| 412 |
+
l_dict.update(self.get_loss(loss, output_known, targets, dn_pos_idx, num_valid_instances[loss]*scalar, is_dn=True))
|
| 413 |
+
|
| 414 |
+
l_dict = {k + f'_dn': v for k, v in l_dict.items()}
|
| 415 |
+
losses.update(l_dict)
|
| 416 |
+
|
| 417 |
+
for loss in self.losses:
|
| 418 |
+
losses.update(self.get_loss(loss, outputs, targets, indices, num_valid_instances[loss]))
|
| 419 |
+
|
| 420 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
| 421 |
+
if 'aux_outputs' in outputs:
|
| 422 |
+
for i, aux_outputs in enumerate(outputs['aux_outputs']):
|
| 423 |
+
indices = self.matcher(aux_outputs, targets)
|
| 424 |
+
for loss in self.losses:
|
| 425 |
+
if loss == 'scale_map':
|
| 426 |
+
continue
|
| 427 |
+
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_valid_instances[loss])
|
| 428 |
+
l_dict = {f'{k}.{i}': v for k, v in l_dict.items()}
|
| 429 |
+
losses.update(l_dict)
|
| 430 |
+
|
| 431 |
+
if 'dn_meta' in outputs:
|
| 432 |
+
if loss == 'scale_map':
|
| 433 |
+
continue
|
| 434 |
+
aux_outputs_known = output_known['aux_outputs'][i]
|
| 435 |
+
l_dict={}
|
| 436 |
+
for loss in self.losses:
|
| 437 |
+
l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_valid_instances[loss]*scalar, is_dn=True))
|
| 438 |
+
l_dict = {k + f'_dn.{i}': v for k, v in l_dict.items()}
|
| 439 |
+
losses.update(l_dict)
|
| 440 |
+
|
| 441 |
+
# if 'scale_map' in outputs:
|
| 442 |
+
# enc_outputs = outputs['enc_outputs']
|
| 443 |
+
# indices = self.matcher.forward_enc(enc_outputs, targets)
|
| 444 |
+
# for loss in ['confs_enc', 'boxes_enc']:
|
| 445 |
+
# l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_valid_instances[loss.replace('_enc','')])
|
| 446 |
+
# l_dict = {k + f'_enc': v for k, v in l_dict.items()}
|
| 447 |
+
# losses.update(l_dict)
|
| 448 |
+
|
| 449 |
+
return losses
|
models/decoder.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import copy
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
from utils.misc import inverse_sigmoid
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 12 |
+
from torch import nn, Tensor
|
| 13 |
+
from torch.nn.init import constant_
|
| 14 |
+
|
| 15 |
+
from .position_encoding import position_encoding_xy
|
| 16 |
+
|
| 17 |
+
from xformers.ops import memory_efficient_attention, fmha
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MLP(nn.Module):
    """A plain multi-layer perceptron (FFN): Linear layers with ReLU between all but
    the last, which stays linear."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer sizes: input -> hidden * (num_layers - 1) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, x):
        last = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = F.relu(x)
        return x
|
| 34 |
+
|
| 35 |
+
class TransformerDecoder(nn.Module):
    """Wrapper around the xformers-based DAB-style decoder stack: builds the layer
    stack, converts boolean attention masks into additive biases, and forwards
    flat (length-packed) memory/query tensors through the decoder."""

    def __init__(self, d_model=512, nhead=8, num_queries=300,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu",
                 return_intermediate_dec=False, query_dim=4,
                 keep_query_pos=False, query_scale_type='cond_elewise',
                 modulate_hw_attn=True,
                 bbox_embed_diff_each_layer=True,
                 ):

        super().__init__()

        # One shared prototype layer; XformerDecoder deep-copies it num_decoder_layers times.
        decoder_layer = XformerDecoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, keep_query_pos=keep_query_pos)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = XformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                      return_intermediate=return_intermediate_dec,
                                      d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
                                      modulate_hw_attn=modulate_hw_attn,
                                      bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)

        self._reset_parameters()
        assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries

    def _reset_parameters(self):
        # Xavier-init every weight matrix (biases and 1-D params keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def mask2bias(self, mask, batch_size):
        """Convert a boolean (L, S) attention mask (True = blocked) into an additive
        float bias of shape (batch, nhead, L, S) with -inf at blocked positions.

        The key dim is first allocated padded to a multiple of 8, then sliced back to
        S — presumably an alignment requirement of xformers' memory layout (TODO confirm).
        """
        if mask is None:
            return None

        assert mask.dtype == torch.bool
        assert mask.ndim == 2
        L, S = mask.shape[0], mask.shape[1]
        pad_size = (S + 7) // 8 * 8
        bias = torch.zeros((batch_size, self.nhead, L, pad_size), device = mask.device)[:,:,:,:S]
        bias.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        return bias

    def forward(self, memory, memory_lens, tgt, tgt_lens, refpoint_embed, pos_embed, self_attn_mask):
        """Run the decoder; returns (hidden_states, reference_points).
        self_attn_mask is only used during denoising training (see XformerDecoderLayer).
        """
        self_attn_bias = self.mask2bias(self_attn_mask, batch_size=len(memory_lens))
        hs, references = self.decoder(memory=memory, memory_lens=memory_lens,
                                      tgt=tgt, tgt_lens=tgt_lens,
                                      pos=pos_embed, refpoints_unsigmoid=refpoint_embed,
                                      self_attn_bias = self_attn_bias)
        return hs, references
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class XformerDecoder(nn.Module):
    """DAB-DETR-style decoder stack operating on flat (length-packed) tensors.

    Each layer attends queries to memory, conditioned on sine embeddings of the
    current 4-D reference boxes (cx, cy, w, h); when `bbox_embed` is attached
    externally, references are iteratively refined layer by layer.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=True,
                 d_model=512, query_dim=4, keep_query_pos=False, query_scale_type='cond_elewise',
                 modulate_hw_attn=False,
                 bbox_embed_diff_each_layer=False,
                 ):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        # Only the intermediate-returning path is supported (see forward's tail).
        assert return_intermediate
        self.query_dim = query_dim

        assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
        self.query_scale_type = query_scale_type
        if query_scale_type == 'cond_elewise':
            # Per-channel transformation conditioned on the decoder output.
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        elif query_scale_type == 'cond_scalar':
            # Single scalar scale conditioned on the decoder output.
            self.query_scale = MLP(d_model, d_model, 1, 2)
        elif query_scale_type == 'fix_elewise':
            # Learned fixed per-layer scale vectors.
            self.query_scale = nn.Embedding(num_layers, d_model)
        else:
            raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))

        # Maps the (query_dim/2 * d_model)-dim sine embedding to a query positional embedding.
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)

        # Set externally (by the model) to enable iterative box refinement.
        self.bbox_embed = None
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer

        if modulate_hw_attn:
            # Predicts a (w, h) pair used to modulate the positional attention.
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)

        # Without keep_query_pos, layers after the first drop their cross-attention
        # query-position projection (DAB-DETR design).
        if not keep_query_pos:
            for layer_id in range(num_layers - 1):
                self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, memory, memory_lens, tgt, tgt_lens,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None, # L_tgt, 4
                self_attn_bias = None):
        # Queries are packed flat; every image is assumed to have tgt_lens[0] queries
        # (NOTE(review): num_queries is taken from the first entry — verify all equal).
        B, num_queries = len(tgt_lens), tgt_lens[0]
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points.view(B, num_queries, self.query_dim)]

        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[:, :self.query_dim] # [L_tgt, 4]
            # get sine embedding for the query vector (cx,cy and w,h halves)
            xy_embed = position_encoding_xy(obj_center[:,0], obj_center[:,1], self.d_model)
            wh_embed = position_encoding_xy(obj_center[:,2], obj_center[:,3], self.d_model)
            query_sine_embed = torch.cat([xy_embed,wh_embed],dim=1) #[L_tgt, 2*d_model]
            query_pos = self.ref_point_head(query_sine_embed)

            # For the first decoder layer, we do not apply transformation over p_s
            if self.query_scale_type != 'fix_elewise':
                if layer_id == 0:
                    pos_transformation = 1
                else:
                    pos_transformation = self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]

            # apply transformation; only the xy half of the sine embedding is kept.
            query_sine_embed = query_sine_embed[:,:self.d_model] * pos_transformation

            # modulated HW attentions: rescale the positional embedding by the ratio of
            # the predicted (w, h) to the current reference box's (w, h).
            if self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 2
                query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)

            output = layer(memory=memory, memory_lens=memory_lens,
                           tgt=output, tgt_lens=tgt_lens,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0),
                           self_attn_bias = self_attn_bias)

            # iter update: refine reference boxes via the (externally attached) bbox head;
            # detach so gradients don't flow through the refinement chain.
            if self.bbox_embed is not None:
                if self.bbox_embed_diff_each_layer:
                    tmp = self.bbox_embed[layer_id](self.norm(output))
                else:
                    tmp = self.bbox_embed(self.norm(output))
                tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
                new_reference_points = tmp[..., :self.query_dim].sigmoid()
                if layer_id != self.num_layers - 1:
                    ref_points.append(new_reference_points.view(B, num_queries, self.query_dim))
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output).view(B, num_queries, self.d_model))

        if self.return_intermediate:
            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate),
                    torch.stack(ref_points),
                ]
            else:
                return [
                    torch.stack(intermediate),
                    reference_points.unsqueeze(0)
                ]

        # Unreachable in practice: return_intermediate is asserted True in __init__.
        return output.unsqueeze(0)
|
| 214 |
+
|
| 215 |
+
class XformerDecoderLayer(nn.Module):
|
| 216 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.0,
|
| 217 |
+
activation="relu", keep_query_pos=False):
|
| 218 |
+
super().__init__()
|
| 219 |
+
# Decoder Self-Attention
|
| 220 |
+
self.sa_qcontent_proj = nn.Linear(d_model, d_model)
|
| 221 |
+
self.sa_qpos_proj = nn.Linear(d_model, d_model)
|
| 222 |
+
self.sa_kcontent_proj = nn.Linear(d_model, d_model)
|
| 223 |
+
self.sa_kpos_proj = nn.Linear(d_model, d_model)
|
| 224 |
+
self.sa_v_proj = nn.Linear(d_model, d_model)
|
| 225 |
+
self.sa_out_proj = nn.Linear(d_model, d_model)
|
| 226 |
+
constant_(self.sa_out_proj.bias, 0.)
|
| 227 |
+
|
| 228 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 229 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 230 |
+
|
| 231 |
+
# Decoder Cross-Attention
|
| 232 |
+
self.ca_qcontent_proj = nn.Linear(d_model, d_model)
|
| 233 |
+
self.ca_qpos_proj = nn.Linear(d_model, d_model)
|
| 234 |
+
self.ca_kcontent_proj = nn.Linear(d_model, d_model)
|
| 235 |
+
self.ca_kpos_proj = nn.Linear(d_model, d_model)
|
| 236 |
+
self.ca_v_proj = nn.Linear(d_model, d_model)
|
| 237 |
+
self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
|
| 238 |
+
self.ca_out_proj = nn.Linear(d_model, d_model)
|
| 239 |
+
constant_(self.ca_out_proj.bias, 0.)
|
| 240 |
+
|
| 241 |
+
self.d_model = d_model
|
| 242 |
+
self.nhead = nhead
|
| 243 |
+
assert self.d_model%self.nhead == 0
|
| 244 |
+
|
| 245 |
+
# Implementation of Feedforward model
|
| 246 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 247 |
+
self.dropout = nn.Dropout(dropout)
|
| 248 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 252 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 253 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 254 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 255 |
+
|
| 256 |
+
self.activation = _get_activation_fn(activation)
|
| 257 |
+
self.keep_query_pos = keep_query_pos
|
| 258 |
+
|
| 259 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
| 260 |
+
return tensor if pos is None else tensor + pos
|
| 261 |
+
|
| 262 |
+
def forward(self, memory, memory_lens, pos,
            tgt, tgt_lens, query_pos, query_sine_embed,
            is_first=False,
            self_attn_bias=None):
    """One decoder layer: query self-attention, cross-attention to the
    encoder memory, then a feed-forward block (all pre-norm, residual).

    Sequences arrive "packed": `memory` is (L_mem, C) with per-sample
    lengths in `memory_lens`, and `tgt` is (L_tgt, C) with per-sample
    lengths in `tgt_lens`.

    Args:
        memory: packed encoder features, shape (L_mem, C).
        memory_lens: list of per-sample memory lengths (sums to L_mem).
        pos: positional embeddings for `memory`, same packed layout.
        tgt: packed decoder queries, shape (L_tgt, C).
        tgt_lens: list of per-sample query counts (sums to L_tgt).
        query_pos: learned positional embedding for the queries.
        query_sine_embed: sine embedding derived from the reference boxes;
            concatenated head-wise onto q in cross-attention (conditional
            DETR style).
        is_first: True for the first decoder layer, where query_pos is also
            injected into the cross-attention q/k.
        self_attn_bias: optional attention bias for self-attention.
    """
    # self_attn_bias is only used for dn_training
    # 'True' indicates that the element should take part in attention

    # NOTE(review): num_queries is taken from tgt_lens[0] and the packed
    # tensors are viewed as (B, num_queries, ...) below — this assumes every
    # sample has the same number of queries; confirm against callers.
    B, num_queries = len(tgt_lens), tgt_lens[0]
    L_mem, C_mem = memory.shape
    L_tgt, C_tgt = tgt.shape
    assert C_mem == C_tgt

    # ========== Begin of Self-Attention =============
    # Pre-norm residual: keep the un-normalized input for the skip connection.
    tgt_b4n = tgt
    tgt = self.norm1(tgt)

    # Separate content/position projections (DAB/conditional-DETR style).
    q_content = self.sa_qcontent_proj(tgt)
    q_pos = self.sa_qpos_proj(query_pos)
    k_content = self.sa_kcontent_proj(tgt)
    k_pos = self.sa_kpos_proj(query_pos)
    v = self.sa_v_proj(tgt)

    q = q_content + q_pos
    k = k_content + k_pos

    # Reshape to (B, num_queries, nhead, head_dim) for xformers.
    q = q.view(B, num_queries, self.nhead, self.d_model // self.nhead)
    k = k.view(B, num_queries, self.nhead, self.d_model // self.nhead)
    v = v.view(B, num_queries, self.nhead, self.d_model // self.nhead)

    tgt2 = memory_efficient_attention(q, k, v, attn_bias=self_attn_bias)
    tgt2 = self.sa_out_proj(tgt2.view(L_tgt, self.d_model))

    tgt = tgt_b4n + self.dropout1(tgt2)
    # ========== End of Self-Attention =============

    # ========== Begin of Cross-Attention =============
    tgt_b4n = tgt
    tgt = self.norm2(tgt)

    q_content = self.ca_qcontent_proj(tgt)
    k_content = self.ca_kcontent_proj(memory)
    v = self.ca_v_proj(memory)

    k_pos = self.ca_kpos_proj(pos)

    # For the first decoder layer, we concatenate the positional embedding predicted from
    # the object query (the positional embedding) into the original query (key) in DETR.
    if is_first or self.keep_query_pos:
        q_pos = self.ca_qpos_proj(query_pos)
        q = q_content + q_pos
        k = k_content + k_pos
    else:
        q = q_content
        k = k_content

    # Queries of all samples are flattened into one "batch-1" sequence;
    # the block-diagonal mask below restores per-sample isolation.
    q = q.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
    query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
    query_sine_embed = query_sine_embed.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
    # Concatenate content and sine-position features along the head dim,
    # doubling the per-head width for cross-attention.
    q = torch.cat([q, query_sine_embed], dim=3)

    k = k.view(1, L_mem, self.nhead, self.d_model//self.nhead)
    k_pos = k_pos.view(1, L_mem, self.nhead, self.d_model//self.nhead)
    k = torch.cat([k, k_pos], dim=3)

    v = v.view(1, L_mem, self.nhead, self.d_model//self.nhead)

    # Block-diagonal bias so each sample's queries only attend to that
    # sample's memory tokens.
    attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(q_seqlen = tgt_lens, kv_seqlen = memory_lens)
    tgt2 = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    tgt2 = self.ca_out_proj(tgt2.view(L_tgt, self.d_model))

    tgt = tgt_b4n + self.dropout2(tgt2)
    # ========== End of Cross-Attention =============

    # FFN (pre-norm residual).
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm3(tgt)))))
    tgt = tgt + self.dropout3(tgt2)

    return tgt
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _get_activation_fn(activation):
|
| 345 |
+
"""Return an activation function given a string"""
|
| 346 |
+
if activation == "relu":
|
| 347 |
+
return F.relu
|
| 348 |
+
if activation == "gelu":
|
| 349 |
+
return F.gelu
|
| 350 |
+
if activation == "glu":
|
| 351 |
+
return F.glu
|
| 352 |
+
if activation == "prelu":
|
| 353 |
+
return nn.PReLU()
|
| 354 |
+
if activation == "selu":
|
| 355 |
+
return F.selu
|
| 356 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 357 |
+
|
| 358 |
+
def _get_clones(module, N):
|
| 359 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def build_decoder(args):
    """Construct the TransformerDecoder from the parsed run configuration."""
    decoder_kwargs = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        num_queries=args.num_queries,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,   # keep per-layer outputs for aux losses
        query_dim=4,                    # (cx, cy, w, h) reference boxes
        activation=args.transformer_activation,
    )
    return TransformerDecoder(**decoder_kwargs)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def torch_attention(query, key, value, attn_bias = None):
    """Plain-PyTorch scaled dot-product attention.

    Inputs are laid out ``(B, M, H, K)`` (heads on dim 2, matching xformers'
    ``memory_efficient_attention``); the result is returned in the same layout.
    """
    scale = 1.0 / query.shape[-1] ** 0.5
    # Move heads next to the batch dim: (B, H, M, K).
    q = (query * scale).transpose(1, 2)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)

    scores = q @ k.transpose(-2, -1)
    if attn_bias is not None:
        scores = scores + attn_bias
    weights = scores.softmax(-1)
    # (no dropout here; kept inference-deterministic)

    out = weights @ v
    return out.transpose(1, 2)
|
models/dn_components.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DINO (https://github.com/IDEA-Research/DINO)
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
| 6 |
+
accuracy, get_world_size, interpolate,
|
| 7 |
+
is_dist_avail_and_initialized, inverse_sigmoid)
|
| 8 |
+
from utils import box_ops
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def prepare_for_cdn(targets, dn_cfg, num_queries, hidden_dim, dn_enc):
    """Build the contrastive-denoising (CDN) query groups, DINO-style.

    A major difference of DINO from DN-DETR is that the author process pattern
    embedding in its detector forward function and use learnable tgt embedding,
    so we change this function a little bit.

    For each ground-truth instance, `dn_number` groups of one positive
    (lightly noised) and one negative (heavily noised) query are created,
    padded per-batch to `pad_size = 2 * dn_number * max_gt_per_image`.

    :param targets: per-image dicts; uses 'boxes', 'labels' and (for the
        'params' embed type) 'poses'/'betas'.
    :param dn_cfg: dict with 'dn_number', 'box_noise_scale',
        'tgt_noise_scale', 'tgt_embed_type', 'dn_labelbook_size'.
    :param num_queries: number of matching (non-denoising) queries.
    :param hidden_dim: transformer hidden dim.
    :param dn_enc: module that embeds noised labels/params into query features.
    :return: (input_query_tgt, input_query_bbox, attn_mask, dn_meta)
    """
    device = targets[0]['boxes'].device

    dn_number = dn_cfg['dn_number']
    box_noise_scale = dn_cfg['box_noise_scale']
    tgt_noise_scale = dn_cfg['tgt_noise_scale']
    # One "known" flag per GT instance, per image.
    known = [(torch.ones_like(t['labels'])) for t in targets]
    batch_size = len(known)
    known_num = [sum(k) for k in known]

    # Adapt the number of DN groups to the GT count so the total number of
    # denoising queries stays roughly constant.
    if int(max(known_num)) == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            dn_number = dn_number // (int(max(known_num) * 2))
        elif dn_number < 1:
            dn_number = 1
    if dn_number == 0:
        dn_number = 1

    unmask_bbox = torch.cat(known)

    boxes = torch.cat([t['boxes'] for t in targets])
    assert boxes.ndim == 2
    batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
    known_indice = torch.nonzero(unmask_bbox)
    known_indice = known_indice.view(-1)
    # Replicate indices for the 2*dn_number (positive+negative) groups.
    known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
    known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)

    single_pad = int(max(known_num))
    pad_size = int(single_pad * 2 * dn_number)
    # Within each group, the first len(boxes) entries are positives and the
    # next len(boxes) are negatives.
    positive_idx = torch.tensor(range(len(boxes))).long().to(device=device).unsqueeze(0).repeat(dn_number, 1)
    positive_idx += (torch.tensor(range(dn_number)) * len(boxes) * 2).long().to(device=device).unsqueeze(1)
    positive_idx = positive_idx.flatten()
    negative_idx = positive_idx + len(boxes)

    # box queries
    known_bboxs = boxes.repeat(2 * dn_number, 1)
    known_bbox_expand = known_bboxs.clone()
    if box_noise_scale > 0:
        # cxcywh -> xyxy for corner-wise jitter.
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
        known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

        # Jitter magnitude: half of width/height for each corner coordinate.
        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = known_bboxs[:, 2:] / 2
        diff[:, 2:] = known_bboxs[:, 2:] / 2

        rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        rand_part = torch.rand_like(known_bboxs)
        # Negatives get noise in [1, 2) x scale — pushed further from the GT.
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + torch.mul(rand_part,
                                              diff).to(device=device) * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        # xyxy -> cxcywh.
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
    # Boxes are fed to the decoder in unsigmoided (logit) space.
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    # tgt queries
    # NOTE(review): if dn_cfg['tgt_embed_type'] is neither 'labels' nor
    # 'params', input_tgt_embed is never assigned and the function raises
    # UnboundLocalError below — confirm config validation happens upstream.
    if dn_cfg['tgt_embed_type'] == 'labels':
        labels = torch.cat([t['labels'] for t in targets])
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_labels_expaned = known_labels.clone()
        if tgt_noise_scale > 0:
            # Flip a random subset of labels to a random class.
            p = torch.rand_like(known_labels_expaned.float())
            chosen_indice = torch.nonzero(p < tgt_noise_scale).view(-1)
            new_label = torch.randint_like(chosen_indice, 0, dn_cfg['dn_labelbook_size'])  # randomly put a new one here
            known_labels_expaned.scatter_(0, chosen_indice, new_label)
        m = known_labels_expaned.long().to(device=device)
        input_tgt_embed = dn_enc(m)
    elif dn_cfg['tgt_embed_type'] == 'params':
        # Noise the SMPL(-X) pose/shape parameters instead of class labels.
        poses = torch.cat([t['poses'] for t in targets])
        betas = torch.cat([t['betas'] for t in targets])
        params = torch.cat([poses, betas], dim=-1)
        assert params.ndim == 2
        known_params = params.repeat(2 * dn_number, 1)
        known_params_expaned = known_params.clone()
        if tgt_noise_scale > 0:
            rand_sign = torch.randint_like(known_params, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
            rand_part = torch.rand_like(known_params)
            rand_part[negative_idx] += 1.0
            rand_part *= rand_sign
            known_params_expaned = known_params_expaned + rand_part * tgt_noise_scale
        m = known_params_expaned.to(device=device)
        input_tgt_embed = dn_enc(m)

    # Scatter the (variable-count) noised queries into fixed-size padded slots.
    padding_tgt = torch.zeros((pad_size, hidden_dim), device=device)
    padding_bbox = torch.zeros((pad_size, 4), device=device)

    input_query_tgt = padding_tgt.repeat(batch_size, 1, 1)
    input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

    map_known_indice = torch.tensor([]).to(device=device)
    if len(known_num):
        map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
        map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
    if len(known_bid):
        input_query_tgt[(known_bid.long(), map_known_indice)] = input_tgt_embed
        input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

    # prepare attn_mask — True entries are *blocked* from attending.
    tgt_size = pad_size + num_queries
    attn_mask = torch.zeros((tgt_size, tgt_size), dtype=bool, device=device)
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct cannot see each other (groups are mutually masked)
    for i in range(dn_number):
        if i == 0:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
        if i == dn_number - 1:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
        else:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }

    return input_query_tgt, input_query_bbox, attn_mask, dn_meta
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def dn_post_process(pred_poses, pred_betas,
                    pred_boxes, pred_confs,
                    pred_j3ds, pred_j2ds, pred_depths,
                    pred_verts, pred_transl,
                    dn_meta, aux_loss, _set_aux_loss):
    """
    post process of dn after output from the transformer

    Splits every prediction tensor into its denoising ("known") slice and its
    matching-query slice along the query axis, stores the denoising outputs in
    ``dn_meta['output_known']`` and returns only the matching-query parts.
    """
    assert dn_meta['pad_size'] > 0
    split = dn_meta['pad_size']

    def _split_stacked(t):
        # layer-stacked tensors: queries live on dim 2
        return t[:, :, :split], t[:, :, split:]

    def _split_flat(t):
        # flat tensors: queries live on dim 1
        return t[:, :split], t[:, split:]

    known_poses, pred_poses = _split_stacked(pred_poses)
    known_betas, pred_betas = _split_stacked(pred_betas)
    known_boxes, pred_boxes = _split_stacked(pred_boxes)
    known_confs, pred_confs = _split_stacked(pred_confs)
    known_j3ds, pred_j3ds = _split_stacked(pred_j3ds)
    known_j2ds, pred_j2ds = _split_stacked(pred_j2ds)
    known_depths, pred_depths = _split_stacked(pred_depths)

    known_verts, pred_verts = _split_flat(pred_verts)
    known_transl, pred_transl = _split_flat(pred_transl)

    # Last decoder layer's denoising outputs feed the DN losses.
    out = {'pred_poses': known_poses[-1], 'pred_betas': known_betas[-1],
           'pred_boxes': known_boxes[-1], 'pred_confs': known_confs[-1],
           'pred_j3ds': known_j3ds[-1], 'pred_j2ds': known_j2ds[-1],
           'pred_depths': known_depths[-1]}

    if aux_loss:
        out['aux_outputs'] = _set_aux_loss(known_poses, known_betas,
                                           known_boxes, known_confs,
                                           known_j3ds, known_j2ds, known_depths)

    dn_meta['output_known'] = out

    return (pred_poses, pred_betas,
            pred_boxes, pred_confs,
            pred_j3ds, pred_j2ds,
            pred_depths, pred_verts,
            pred_transl)
|
| 192 |
+
|
| 193 |
+
|
models/encoders/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DINOv2 (https://github.com/facebookresearch/dinov2)
|
| 2 |
+
from models.encoders.dinov2.models.vision_transformer import vit_base, vit_large
|
| 3 |
+
import torch
|
| 4 |
+
from configs.paths import dinov2_vitb14_path, dinov2_vitl14_path
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
def build_encoder(args):
    """Build the DINOv2 ViT backbone selected by ``args.encoder``.

    Supports 'vitb' (ViT-B/14) and 'vitl' (ViT-L/14). When SAT with additional
    blocks is enabled, the first ``get_map_layer`` blocks are duplicated into
    an ``additional_blocks`` branch, initialized from the same pretrained
    weights. Pretrained DINOv2 weights are only loaded in 'train' mode.

    Raises:
        NotImplementedError: for any other ``args.encoder`` value.
    """
    num_additional_blocks = 0
    if args.sat_cfg['use_sat'] and args.sat_cfg['use_additional_blocks']:
        num_additional_blocks = args.sat_cfg['get_map_layer']

    weights = None
    if args.encoder == 'vitb':
        # Hyperparameters match the official DINOv2 ViT-B/14 release.
        model = vit_base(img_size = 518,
                         patch_size = 14,
                         init_values = 1.0,
                         ffn_layer = "mlp",
                         block_chunks = 0,
                         num_register_tokens = 0,
                         interpolate_antialias = False,
                         interpolate_offset = 0.1,
                         num_additional_blocks = num_additional_blocks)
        if args.mode.lower() == 'train':
            # NOTE(review): torch.load without map_location restores tensors
            # to the device they were saved on — confirm checkpoints are CPU.
            weights = torch.load(dinov2_vitb14_path)
    elif args.encoder == 'vitl':
        model = vit_large(img_size = 518,
                          patch_size = 14,
                          init_values = 1.0,
                          ffn_layer = "mlp",
                          block_chunks = 0,
                          num_register_tokens = 0,
                          interpolate_antialias = False,
                          interpolate_offset = 0.1,
                          num_additional_blocks = num_additional_blocks)
        if args.mode.lower() == 'train':
            weights = torch.load(dinov2_vitl14_path)
    else:
        raise NotImplementedError

    if weights is not None:
        if args.sat_cfg['use_sat'] and args.sat_cfg['use_additional_blocks']:
            # Mirror the first blocks' weights into the additional branch so
            # strict=True loading succeeds.
            add_blocks_weights(weights, args.sat_cfg['get_map_layer'])
        print('Loading pretrained DINOv2...')
        model.load_state_dict(weights,strict=True)

    return model
|
| 47 |
+
|
| 48 |
+
def add_blocks_weights(weights, num_layers):
    """Copy the first *num_layers* 'blocks.*' entries of a state dict to
    matching 'additional_blocks.*' keys (mutates *weights* in place)."""
    for key in list(weights.keys()):
        if not key.startswith('blocks'):
            continue
        layer_idx = int(key.split('.')[1])
        if layer_idx < num_layers:
            mirror_key = key.replace('blocks', 'additional_blocks')
            weights[mirror_key] = copy.deepcopy(weights[key])
|
models/encoders/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
models/encoders/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 22 |
+
try:
|
| 23 |
+
if XFORMERS_ENABLED:
|
| 24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 25 |
+
|
| 26 |
+
XFORMERS_AVAILABLE = True
|
| 27 |
+
warnings.warn("xFormers is available (Attention)")
|
| 28 |
+
else:
|
| 29 |
+
warnings.warn("xFormers is disabled (Attention)")
|
| 30 |
+
raise ImportError
|
| 31 |
+
except ImportError:
|
| 32 |
+
XFORMERS_AVAILABLE = False
|
| 33 |
+
warnings.warn("xFormers is not available (Attention)")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
    """Standard ViT multi-head self-attention with a fused qkv projection."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Module creation order is kept stable for RNG-reproducible init.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        q, k, v = fused[0] * self.scale, fused[1], fused[2]

        scores = q @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        out = self.proj(out)
        return self.proj_drop(out)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MemEffAttention(Attention):
    """Attention variant dispatching to xFormers' memory-efficient kernel.

    Falls back to the plain ``Attention.forward`` when xFormers is absent;
    ``attn_bias`` is only supported through the xFormers path.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        batch, seq_len, channels = x.shape
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
        q, k, v = unbind(fused, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, seq_len, channels])

        out = self.proj(out)
        return self.proj_drop(out)
|
models/encoders/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
|
| 40 |
+
warnings.warn("xFormers is not available (Block)")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP) with LayerScale and
    optional batch-wise stochastic depth (DINOv2 style)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale is only enabled when init_values is truthy.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Rate for the sample-level (whole-batch-row) stochastic-depth path.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            # Inference / no drop-path: plain residual composition.
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-wise stochastic depth: apply *residual_func* to a random subset
    of batch rows and add the rescaled residual back into *x*."""
    # 1) pick the surviving rows via a random permutation
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = (torch.randperm(batch, device=x.device))[:keep]

    # 2) residual for the surviving samples only
    residual = residual_func(x[kept_rows]).flatten(1)

    # 3) rescale so the expected residual magnitude matches full-batch training
    scale = batch / keep
    out = torch.index_add(x.flatten(1), 0, kept_rows, residual.to(dtype=x.dtype), alpha=scale)
    return out.view_as(x)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the random subset of batch rows to keep for stochastic depth and
    the matching residual rescale factor."""
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(batch, device=x.device))[:keep]
    return brange, batch / keep
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add the rescaled *residual* into the rows of *x* selected by *brange*.

    Without a scaling vector the result comes back flattened over the
    non-batch dims (callers ``.view_as(x)`` it); with one, xFormers'
    ``scaled_index_add`` keeps the original shape.
    """
    if scaling_vector is None:
        return torch.index_add(
            x.flatten(1),
            0,
            brange,
            residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Process-wide cache of block-diagonal attention biases, keyed by the tuple
# of per-tensor (batch_size, seq_len) shapes — see get_attn_bias_and_cat.
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Concatenates a list of (B_i, N_i, C) tensors into a single "batch-1"
    sequence and returns the matching xFormers block-diagonal attention bias
    (memoized in the module-level attn_bias_cache). When `branges` is given,
    only those batch rows of each tensor are selected before concatenation.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One sequence of length N_i per surviving batch row of each tensor.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
        # Stash the grouping so the bias' split() can undo the concatenation.
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused select+cat from xFormers; result viewed as (1, total_tokens, C).
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual update applied jointly to a list of tensors.

    Subsamples each tensor's batch, runs `residual_func` once over the nested
    concatenation of the kept rows, then scatter-adds the rescaled residuals
    back into the corresponding inputs.
    """
    # Per-tensor subset indices and the matching residual rescale factors.
    subsets = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [kept for kept, _ in subsets]
    scale_factors = [scale for _, scale in subsets]

    # Nest the kept rows into one sequence guarded by a block-diagonal mask.
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # Single pass of the residual branch over the nested tensor, then un-nest.
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, kept, res, scale, scaling_vector).view_as(x)
        for x, kept, res, scale in zip(x_list, branges, residual_list, scale_factors)
    ]
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class NestedTensorBlock(Block):
    """Transformer block that can also process a *list* of tensors by nesting
    them into one sequence with a block-diagonal attention mask (xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        # The nested path relies on MemEffAttention's attn_bias support.
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic-depth path: the LayerScale gammas are applied inside
            # add_residual (via scaling_vector), so they are omitted here.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # BUGFIX: guard on ls2 — the attribute actually dereferenced.
                # The original tested `isinstance(self.ls1, LayerScale)` here,
                # which would raise AttributeError if ls1 were a LayerScale
                # while ls2 was not.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:
            # Plain path: standard pre-norm residual block over the nested sequence.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch: plain Block forward for a Tensor, nested path for a list."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
models/encoders/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
    """DINO projection head: a bottleneck MLP followed by a weight-normalised
    prototype layer whose weight magnitude is pinned to 1 at construction."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Initialise the MLP *before* creating the prototype layer so the
        # weight-norm parametrisation below is not overwritten.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases, linear layers only.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        # fp16 cannot represent 1e-12, so loosen the normalisation eps there.
        eps = 1e-6 if projected.dtype == torch.float16 else 1e-12
        normalized = nn.functional.normalize(projected, dim=-1, p=2, eps=eps)
        return self.last_layer(normalized)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
| 45 |
+
if nlayers == 1:
|
| 46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
| 47 |
+
else:
|
| 48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
| 49 |
+
if use_bn:
|
| 50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 51 |
+
layers.append(nn.GELU())
|
| 52 |
+
for _ in range(nlayers - 2):
|
| 53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
| 54 |
+
if use_bn:
|
| 55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 56 |
+
layers.append(nn.GELU())
|
| 57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
| 58 |
+
return nn.Sequential(*layers)
|
models/encoders/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of x (stochastic depth), rescaling the
    survivors by 1/keep_prob so the expected value is unchanged. Acts as the
    identity when drop_prob == 0 or outside training mode."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        # BUGFIX: the default used to be None, which made `1 - drop_prob`
        # raise a TypeError the first time the module ran in training mode.
        # 0.0 preserves the same (identity) behaviour everywhere else and is
        # safe in training mode.
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Only active when self.training is True and drop_prob > 0.
        return drop_path(x, self.drop_prob, self.training)
|
models/encoders/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (timm-style)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts near zero (default 1e-5) so the branch opens gradually.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
models/encoders/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
    """Standard transformer MLP: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Missing widths default to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
models/encoders/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Normalise x to a pair: an int becomes (x, x); a 2-tuple passes through.

    Anything else (wrong tuple length or wrong type) fails an assertion.
    """
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, keep the (B, H', W', D) grid layout.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1])

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # A non-overlapping convolution is exactly a per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B H'W' C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        grid_h, grid_w = self.patches_resolution
        n_patches = grid_h * grid_w
        total = n_patches * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # self.norm is nn.Identity when no norm_layer is given, so this branch
        # is always taken; kept for parity with the original accounting.
        if self.norm is not None:
            total += n_patches * self.embed_dim
        return total
|
models/encoders/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: one fused Linear yields both the gate and value
    halves; the output is w3(silu(gate) * value)."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # act_layer and drop are accepted for interface parity but unused.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Prefer xFormers' fused SwiGLU when the library is importable and not
# explicitly disabled via the XFORMERS_DISABLED environment variable;
# otherwise fall back to the pure-PyTorch SwiGLUFFN defined above.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (SwiGLU)")
    else:
        # Disabled by the environment: route into the fallback branch below.
        warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU variant whose hidden width is scaled by 2/3 (keeping the
    parameter count comparable to a plain 4x MLP) and rounded up to a
    multiple of 8 for hardware-friendly shapes."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 scaling, then round up to the next multiple of 8.
        scaled = int(hidden_features * 2 / 3)
        hidden_features = ((scaled + 7) // 8) * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
models/encoders/dinov2/models/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from . import vision_transformer as vits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("dinov2")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate DINOv2 ViT(s) from a config namespace.

    Returns (teacher, embed_dim) when only_teacher is set, otherwise
    (student, teacher, embed_dim). Non-ViT arch names fall through and
    return None, as in the original.
    """
    # Strip the legacy "_memeff" suffix before looking the arch up.
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None

    shared_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )
    model_cls = vits.__dict__[args.arch]

    teacher = model_cls(**shared_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # The student additionally gets stochastic depth.
    student = model_cls(
        **shared_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    """Build model(s) from a config object exposing `student` and `crops` sections."""
    student_args = cfg.student
    crop_size = cfg.crops.global_crops_size
    return build_model(student_args, only_teacher=only_teacher, img_size=crop_size)
|
models/encoders/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DINOv2 (https://github.com/facebookresearch/dinov2)
|
| 2 |
+
# ------------------------------------------------------------------------
|
| 3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 6 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
# References:
|
| 9 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 10 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
import math
|
| 14 |
+
import logging
|
| 15 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.utils.checkpoint
|
| 20 |
+
from torch.nn.init import trunc_normal_
|
| 21 |
+
|
| 22 |
+
import copy
|
| 23 |
+
|
| 24 |
+
from models.encoders.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger("dinov2")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively invoke fn(module=..., name=...) over a module tree.

    depth_first controls whether fn sees children before their parent;
    include_root controls whether fn is called on `module` itself. Names are
    dot-joined paths as in Module.named_modules(). Returns `module`.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that is itself callable: applies its blocks in sequence."""

    def forward(self, x):
        out = x
        for block in self:
            out = block(out)
        return out
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DinoVisionTransformer(nn.Module):
|
| 49 |
+
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        num_additional_blocks = 0,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            num_additional_blocks: (int) number of leading blocks to duplicate
                into `self.additional_blocks` (deep copies; incompatible with
                block chunking)
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1  # the single [CLS] token
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.img_size = img_size

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learned [CLS] token and position embeddings (patches + cls slot).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        # Optional register tokens; they carry no positional embedding.
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        # Per-block stochastic-depth rates: constant or linearly increasing.
        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Resolve the FFN implementation from its string name.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            # Group blocks into chunks (for FSDP wrapping), padding each chunk
            # with nn.Identity so block indices stay globally consistent.
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        if num_additional_blocks > 0:
            # Deep-copied duplicates of the first N blocks; not supported with
            # chunked block lists.
            assert not self.chunked_blocks
            self.additional_blocks = copy.deepcopy(self.blocks[:num_additional_blocks])

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Token substituted for masked patches during masked-image modelling.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()
|
| 181 |
+
|
| 182 |
+
    def init_weights(self):
        """Initialize the learned tokens/position embeddings, then apply
        timm-style weight init to every submodule."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        # NOTE(review): init_weights_vit_timm is defined elsewhere in this
        # module (not visible in this chunk).
        named_apply(init_weights_vit_timm, self)
|
| 188 |
+
|
| 189 |
+
def interpolate_pos_encoding(self, x, w, h, with_cls_token = True):
|
| 190 |
+
previous_dtype = x.dtype
|
| 191 |
+
# npatch = x.shape[1] - 1 if with_cls_token else x.shape[1]
|
| 192 |
+
N = self.pos_embed.shape[1] - 1
|
| 193 |
+
# if npatch == N and w == h:
|
| 194 |
+
# return self.pos_embed
|
| 195 |
+
pos_embed = self.pos_embed.float()
|
| 196 |
+
class_pos_embed = pos_embed[:, 0]
|
| 197 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 198 |
+
dim = x.shape[-1]
|
| 199 |
+
w0 = w // self.patch_size
|
| 200 |
+
h0 = h // self.patch_size
|
| 201 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 202 |
+
assert N == M * M
|
| 203 |
+
kwargs = {}
|
| 204 |
+
if self.interpolate_offset:
|
| 205 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 206 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 207 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 208 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 209 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 210 |
+
else:
|
| 211 |
+
# Simply specify an output size instead of a scale factor
|
| 212 |
+
kwargs["size"] = (w0, h0)
|
| 213 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 214 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 215 |
+
mode="bicubic",
|
| 216 |
+
antialias=self.interpolate_antialias,
|
| 217 |
+
**kwargs,
|
| 218 |
+
)
|
| 219 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 220 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 221 |
+
if with_cls_token:
|
| 222 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 223 |
+
else:
|
| 224 |
+
return patch_pos_embed.to(previous_dtype)
|
| 225 |
+
|
| 226 |
+
def interpolate_pos_encoding2(self, x, input_size, feature_h, feature_w):
|
| 227 |
+
previous_dtype = x.dtype
|
| 228 |
+
N = self.pos_embed.shape[1] - 1
|
| 229 |
+
|
| 230 |
+
pos_embed = self.pos_embed.float()
|
| 231 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 232 |
+
dim = x.shape[-1]
|
| 233 |
+
w0 = input_size // self.patch_size
|
| 234 |
+
h0 = input_size // self.patch_size
|
| 235 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 236 |
+
assert N == M * M
|
| 237 |
+
kwargs = {}
|
| 238 |
+
if self.interpolate_offset:
|
| 239 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 240 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 241 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 242 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 243 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 244 |
+
else:
|
| 245 |
+
# Simply specify an output size instead of a scale factor
|
| 246 |
+
kwargs["size"] = (w0, h0)
|
| 247 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 248 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 249 |
+
mode="bicubic",
|
| 250 |
+
antialias=self.interpolate_antialias,
|
| 251 |
+
**kwargs,
|
| 252 |
+
)
|
| 253 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 254 |
+
patch_pos_embed = patch_pos_embed[...,:feature_h,:feature_w].permute(0, 2, 3, 1).reshape(1, -1, dim)
|
| 255 |
+
return patch_pos_embed.to(previous_dtype)
|
| 256 |
+
|
| 257 |
+
def interpolate_pos_encoding3(self, target_size):
|
| 258 |
+
# previous_dtype = x.dtype
|
| 259 |
+
N = self.pos_embed.shape[1] - 1
|
| 260 |
+
pos_embed = self.pos_embed.float()
|
| 261 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 262 |
+
dim = self.embed_dim
|
| 263 |
+
w0 = target_size
|
| 264 |
+
h0 = target_size
|
| 265 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 266 |
+
assert N == M * M
|
| 267 |
+
kwargs = {}
|
| 268 |
+
if self.interpolate_offset:
|
| 269 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 270 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 271 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 272 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 273 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 274 |
+
else:
|
| 275 |
+
# Simply specify an output size instead of a scale factor
|
| 276 |
+
kwargs["size"] = (w0, h0)
|
| 277 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 278 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 279 |
+
mode="bicubic",
|
| 280 |
+
antialias=self.interpolate_antialias,
|
| 281 |
+
**kwargs,
|
| 282 |
+
)
|
| 283 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 284 |
+
patch_pos_embed = patch_pos_embed.squeeze(0).permute(1,2,0)
|
| 285 |
+
|
| 286 |
+
return patch_pos_embed
|
| 287 |
+
|
| 288 |
+
def prepare_tokens_with_masks(self, x, masks=None, with_pos_embed = True):
|
| 289 |
+
B, nc, w, h = x.shape
|
| 290 |
+
x = self.patch_embed(x)
|
| 291 |
+
if masks is not None:
|
| 292 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 293 |
+
|
| 294 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 295 |
+
|
| 296 |
+
if with_pos_embed:
|
| 297 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 298 |
+
|
| 299 |
+
if self.register_tokens is not None:
|
| 300 |
+
x = torch.cat(
|
| 301 |
+
(
|
| 302 |
+
x[:, :1],
|
| 303 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
| 304 |
+
x[:, 1:],
|
| 305 |
+
),
|
| 306 |
+
dim=1,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
return x
|
| 310 |
+
|
| 311 |
+
def prepare_tokens_with_masks2(self, x, masks):
|
| 312 |
+
assert masks.ndim == 3
|
| 313 |
+
B, nc, w, h = x.shape
|
| 314 |
+
token_lens = masks.flatten(1).sum(1).tolist()
|
| 315 |
+
patched_x = x.view(B, 3, w//self.patch_size, self.patch_size, h//self.patch_size, self.patch_size)
|
| 316 |
+
patched_x = patched_x.permute(0, 2, 4, 1, 3, 5)
|
| 317 |
+
x = self.patch_embed.norm(self.patch_embed.proj(patched_x[masks]).flatten(1))
|
| 318 |
+
pos_embed = self.interpolate_pos_encoding(x, w, h, with_cls_token=False).repeat(B, 1, 1)
|
| 319 |
+
x = x + pos_embed[masks.flatten(1)]
|
| 320 |
+
|
| 321 |
+
cr_token = self.cls_token.view(1,-1) + self.pos_embed.float()[:,0].view(1,-1)
|
| 322 |
+
if self.register_tokens is not None:
|
| 323 |
+
cr_token = torch.cat([cr_token, self.register_tokens.view(self.num_register_tokens, -1)])
|
| 324 |
+
|
| 325 |
+
# x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 326 |
+
|
| 327 |
+
# if self.register_tokens is not None:
|
| 328 |
+
# x = torch.cat(
|
| 329 |
+
# (
|
| 330 |
+
# x[:, :1],
|
| 331 |
+
# self.register_tokens.expand(x.shape[0], -1, -1),
|
| 332 |
+
# x[:, 1:],
|
| 333 |
+
# ),
|
| 334 |
+
# dim=1,
|
| 335 |
+
# )
|
| 336 |
+
|
| 337 |
+
return x, token_lens, cr_token
|
| 338 |
+
|
| 339 |
+
def forward_features_list(self, x_list, masks_list):
|
| 340 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 341 |
+
for blk in self.blocks:
|
| 342 |
+
x = blk(x)
|
| 343 |
+
|
| 344 |
+
all_x = x
|
| 345 |
+
output = []
|
| 346 |
+
for x, masks in zip(all_x, masks_list):
|
| 347 |
+
x_norm = self.norm(x)
|
| 348 |
+
output.append(
|
| 349 |
+
{
|
| 350 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 351 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 352 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 353 |
+
"x_prenorm": x,
|
| 354 |
+
"masks": masks,
|
| 355 |
+
}
|
| 356 |
+
)
|
| 357 |
+
return output
|
| 358 |
+
|
| 359 |
+
def forward_features(self, x, masks=None):
|
| 360 |
+
if isinstance(x, list):
|
| 361 |
+
return self.forward_features_list(x, masks)
|
| 362 |
+
|
| 363 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 364 |
+
|
| 365 |
+
for blk in self.blocks:
|
| 366 |
+
x = blk(x)
|
| 367 |
+
|
| 368 |
+
x_norm = self.norm(x)
|
| 369 |
+
return {
|
| 370 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 371 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 372 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 373 |
+
"x_prenorm": x,
|
| 374 |
+
"masks": masks,
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
def forward_specific_layers(self, x, start=0, end=None, norm=True):
|
| 378 |
+
assert not self.chunked_blocks
|
| 379 |
+
if end is None:
|
| 380 |
+
end = len(self.blocks)
|
| 381 |
+
for blk in self.blocks[start:end]:
|
| 382 |
+
x = blk(x)
|
| 383 |
+
out = x[:, 1 + self.num_register_tokens :]
|
| 384 |
+
if norm:
|
| 385 |
+
out = self.norm(out)
|
| 386 |
+
return x, out
|
| 387 |
+
|
| 388 |
+
def forward_specific_layers_list(self, x_list, start=0, end=None, norm=True, get_feature=True):
|
| 389 |
+
assert not self.chunked_blocks
|
| 390 |
+
if end is None:
|
| 391 |
+
end = len(self.blocks)
|
| 392 |
+
for blk in self.blocks[start:end]:
|
| 393 |
+
x_list = blk(x_list)
|
| 394 |
+
|
| 395 |
+
if get_feature:
|
| 396 |
+
out_list = [x[:, 1 + self.num_register_tokens:, :] for x in x_list]
|
| 397 |
+
if norm:
|
| 398 |
+
out_list = [self.norm(out) for out in out_list]
|
| 399 |
+
return x_list, out_list
|
| 400 |
+
else:
|
| 401 |
+
return x_list
|
| 402 |
+
|
| 403 |
+
def forward_additional_layers_list(self, x_list, start=0, end=None, norm=True, get_feature=True):
|
| 404 |
+
assert not self.chunked_blocks
|
| 405 |
+
if end is None:
|
| 406 |
+
end = len(self.additional_blocks)
|
| 407 |
+
for blk in self.additional_blocks[start:end]:
|
| 408 |
+
x_list = blk(x_list)
|
| 409 |
+
|
| 410 |
+
if get_feature:
|
| 411 |
+
out_list = [x[:, 1 + self.num_register_tokens:, :] for x in x_list]
|
| 412 |
+
if norm:
|
| 413 |
+
out_list = [self.norm(out) for out in out_list]
|
| 414 |
+
return x_list, out_list
|
| 415 |
+
else:
|
| 416 |
+
return x_list
|
| 417 |
+
|
| 418 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 419 |
+
x = self.prepare_tokens_with_masks(x)
|
| 420 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 421 |
+
output, total_block_len = [], len(self.blocks)
|
| 422 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 423 |
+
for i, blk in enumerate(self.blocks):
|
| 424 |
+
x = blk(x)
|
| 425 |
+
if i in blocks_to_take:
|
| 426 |
+
output.append(x)
|
| 427 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 428 |
+
return output
|
| 429 |
+
|
| 430 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 431 |
+
x = self.prepare_tokens_with_masks(x)
|
| 432 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 433 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 434 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 435 |
+
for block_chunk in self.blocks:
|
| 436 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 437 |
+
x = blk(x)
|
| 438 |
+
if i in blocks_to_take:
|
| 439 |
+
output.append(x)
|
| 440 |
+
i += 1
|
| 441 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
def get_intermediate_layers(
|
| 445 |
+
self,
|
| 446 |
+
x: torch.Tensor,
|
| 447 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 448 |
+
reshape: bool = False,
|
| 449 |
+
return_class_token: bool = False,
|
| 450 |
+
norm=True,
|
| 451 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 452 |
+
if self.chunked_blocks:
|
| 453 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 454 |
+
else:
|
| 455 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 456 |
+
if norm:
|
| 457 |
+
outputs = [self.norm(out) for out in outputs]
|
| 458 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 459 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
| 460 |
+
if reshape:
|
| 461 |
+
B, _, w, h = x.shape
|
| 462 |
+
outputs = [
|
| 463 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 464 |
+
for out in outputs
|
| 465 |
+
]
|
| 466 |
+
if return_class_token:
|
| 467 |
+
return tuple(zip(outputs, class_tokens))
|
| 468 |
+
return tuple(outputs)
|
| 469 |
+
|
| 470 |
+
def forward(self, *args, is_training=False, **kwargs):
|
| 471 |
+
ret = self.forward_features(*args, **kwargs)
|
| 472 |
+
if is_training:
|
| 473 |
+
return ret
|
| 474 |
+
else:
|
| 475 |
+
return self.head(ret["x_norm_clstoken"])
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 479 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 480 |
+
if isinstance(module, nn.Linear):
|
| 481 |
+
trunc_normal_(module.weight, std=0.02)
|
| 482 |
+
if module.bias is not None:
|
| 483 |
+
nn.init.zeros_(module.bias)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 487 |
+
model = DinoVisionTransformer(
|
| 488 |
+
patch_size=patch_size,
|
| 489 |
+
embed_dim=384,
|
| 490 |
+
depth=12,
|
| 491 |
+
num_heads=6,
|
| 492 |
+
mlp_ratio=4,
|
| 493 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 494 |
+
num_register_tokens=num_register_tokens,
|
| 495 |
+
**kwargs,
|
| 496 |
+
)
|
| 497 |
+
return model
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 501 |
+
model = DinoVisionTransformer(
|
| 502 |
+
patch_size=patch_size,
|
| 503 |
+
embed_dim=768,
|
| 504 |
+
depth=12,
|
| 505 |
+
num_heads=12,
|
| 506 |
+
mlp_ratio=4,
|
| 507 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 508 |
+
num_register_tokens=num_register_tokens,
|
| 509 |
+
**kwargs,
|
| 510 |
+
)
|
| 511 |
+
return model
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 515 |
+
model = DinoVisionTransformer(
|
| 516 |
+
patch_size=patch_size,
|
| 517 |
+
embed_dim=1024,
|
| 518 |
+
depth=24,
|
| 519 |
+
num_heads=16,
|
| 520 |
+
mlp_ratio=4,
|
| 521 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 522 |
+
num_register_tokens=num_register_tokens,
|
| 523 |
+
**kwargs,
|
| 524 |
+
)
|
| 525 |
+
return model
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 529 |
+
"""
|
| 530 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 531 |
+
"""
|
| 532 |
+
model = DinoVisionTransformer(
|
| 533 |
+
patch_size=patch_size,
|
| 534 |
+
embed_dim=1536,
|
| 535 |
+
depth=40,
|
| 536 |
+
num_heads=24,
|
| 537 |
+
mlp_ratio=4,
|
| 538 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 539 |
+
num_register_tokens=num_register_tokens,
|
| 540 |
+
**kwargs,
|
| 541 |
+
)
|
| 542 |
+
return model
|
models/human_models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .smpl_models import SMPL_Layer, smpl_gendered
|
models/human_models/smpl_models.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import smplx
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pickle
|
| 6 |
+
import os.path as osp
|
| 7 |
+
from configs.paths import smpl_model_path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SMPL_Layer(nn.Module):
|
| 11 |
+
def __init__(self, model_path, with_genders = True, **kwargs):
|
| 12 |
+
"""
|
| 13 |
+
Extension of the SMPL Layer with gendered inputs.
|
| 14 |
+
"""
|
| 15 |
+
super().__init__()
|
| 16 |
+
smpl_kwargs = {'create_global_orient': False, 'create_body_pose': False,
|
| 17 |
+
'create_betas': False, 'create_transl': False}
|
| 18 |
+
smpl_kwargs.update(kwargs)
|
| 19 |
+
self.with_genders = with_genders
|
| 20 |
+
if self.with_genders:
|
| 21 |
+
self.layer_n = smplx.create(model_path, 'smpl', gender='neutral', **smpl_kwargs)
|
| 22 |
+
self.layer_m = smplx.create(model_path, 'smpl', gender='male', **smpl_kwargs)
|
| 23 |
+
self.layer_f = smplx.create(model_path, 'smpl', gender='female', **smpl_kwargs)
|
| 24 |
+
self.layers = {'neutral': self.layer_n, 'male': self.layer_m, 'female': self.layer_f}
|
| 25 |
+
else:
|
| 26 |
+
self.layer_n = smplx.create(model_path, 'smpl', gender='neutral', **smpl_kwargs)
|
| 27 |
+
self.layers = {'neutral': self.layer_n}
|
| 28 |
+
|
| 29 |
+
self.vertex_num = 6890
|
| 30 |
+
self.faces = self.layer_n.faces
|
| 31 |
+
|
| 32 |
+
self.body_vertex_idx = np.load(osp.join(model_path, 'smpl', 'body_verts_smpl.npy'))
|
| 33 |
+
self.smpl2h36m_regressor = np.load(osp.join(model_path, 'smpl', 'J_regressor_h36m_correct.npy'))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def forward_single_gender(self, poses, betas, gender='neutral'):
|
| 37 |
+
bs = poses.shape[0]
|
| 38 |
+
if poses.ndim == 2:
|
| 39 |
+
poses = poses.view(bs, -1, 3)
|
| 40 |
+
|
| 41 |
+
assert poses.shape[1] == 24
|
| 42 |
+
pose_params = {'global_orient': poses[:, :1, :],
|
| 43 |
+
'body_pose': poses[:, 1:, :]}
|
| 44 |
+
|
| 45 |
+
smpl_output = self.layers[gender](betas=betas, **pose_params)
|
| 46 |
+
return smpl_output.vertices, smpl_output.joints
|
| 47 |
+
|
| 48 |
+
def forward(self, poses, betas, genders = None):
|
| 49 |
+
bs = poses.shape[0]
|
| 50 |
+
assert poses.shape[0] == betas.shape[0]
|
| 51 |
+
if genders is None:
|
| 52 |
+
return self.forward_single_gender(poses, betas)
|
| 53 |
+
else:
|
| 54 |
+
assert len(genders) == bs
|
| 55 |
+
assert set(genders) <= {'male', 'female'}
|
| 56 |
+
assert self.with_genders
|
| 57 |
+
|
| 58 |
+
male_idx = [i for i, gender in enumerate(genders) if gender == 'male']
|
| 59 |
+
if len(male_idx) == bs:
|
| 60 |
+
return self.forward_single_gender(poses, betas, gender='male')
|
| 61 |
+
elif len(male_idx) == 0:
|
| 62 |
+
return self.forward_single_gender(poses, betas, gender='female')
|
| 63 |
+
else:
|
| 64 |
+
vertices, joints = self.forward_single_gender(poses, betas, gender='female')
|
| 65 |
+
vertices[male_idx], joints[male_idx] =\
|
| 66 |
+
self.forward_single_gender(poses[male_idx], betas[male_idx], gender='male')
|
| 67 |
+
return vertices, joints
|
| 68 |
+
|
| 69 |
+
smpl_gendered = SMPL_Layer(smpl_model_path, with_genders = True)
|
models/matcher.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from scipy.optimize import linear_sum_assignment
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HungarianMatcher(nn.Module):
|
| 11 |
+
"""This class computes an assignment between the targets and the predictions of the network
|
| 12 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
|
| 13 |
+
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
|
| 14 |
+
while the others are un-matched (and thus treated as non-objects).
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self,
|
| 18 |
+
cost_conf: float = 1,
|
| 19 |
+
cost_bbox: float = 1,
|
| 20 |
+
cost_giou: float = 1,
|
| 21 |
+
cost_kpts: float = 10,
|
| 22 |
+
j2ds_norm_scale: float = 518,
|
| 23 |
+
):
|
| 24 |
+
"""Creates the matcher
|
| 25 |
+
Params:
|
| 26 |
+
cost_class: This is the relative weight of the classification error in the matching cost
|
| 27 |
+
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
|
| 28 |
+
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
|
| 29 |
+
"""
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.cost_conf = cost_conf
|
| 32 |
+
self.cost_bbox = cost_bbox
|
| 33 |
+
self.cost_giou = cost_giou
|
| 34 |
+
self.cost_kpts = cost_kpts
|
| 35 |
+
self.j2ds_norm_scale = j2ds_norm_scale
|
| 36 |
+
assert cost_conf != 0 or cost_bbox != 0 or cost_giou != 0 or cost_kpts != 0, "all costs cant be 0"
|
| 37 |
+
|
| 38 |
+
# self.focal_alpha = focal_alpha
|
| 39 |
+
|
| 40 |
+
@torch.no_grad()
|
| 41 |
+
def forward_enc(self, outputs, targets):
|
| 42 |
+
""" Performs the matching
|
| 43 |
+
Params:
|
| 44 |
+
outputs: This is a dict that contains at least these entries:
|
| 45 |
+
"pred_confs": Tensor of flattened confidence score, [total_lens, 1]
|
| 46 |
+
"pred_boxes": Tensor of flattened boxes, [total_lens, 4]
|
| 47 |
+
"lens": num of predictions for each sample in the batch, sum(lens) == total_lens
|
| 48 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
| 49 |
+
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
|
| 50 |
+
Returns:
|
| 51 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
| 52 |
+
- index_i is the indices of the selected predictions (in order)
|
| 53 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
| 54 |
+
For each batch element, it holds:
|
| 55 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
| 56 |
+
"""
|
| 57 |
+
out_conf = outputs['pred_confs']
|
| 58 |
+
out_bbox = outputs["pred_boxes"]
|
| 59 |
+
lens = outputs['lens']
|
| 60 |
+
assert len(lens) == len(targets)
|
| 61 |
+
assert tuple(out_conf.shape) == (sum(lens),1)
|
| 62 |
+
assert tuple(out_bbox.shape) == (sum(lens),4)
|
| 63 |
+
|
| 64 |
+
# Also concat the target labels and boxes
|
| 65 |
+
tgt_bbox = torch.cat([v["boxes"] for v in targets])
|
| 66 |
+
|
| 67 |
+
# Compute the confidence cost.
|
| 68 |
+
alpha = 0.25
|
| 69 |
+
gamma = 2.0
|
| 70 |
+
cost_conf = alpha * ((1 - out_conf) ** gamma) * (-(out_conf + 1e-8).log())
|
| 71 |
+
# cost_conf = -(out_conf+1e-8).log()
|
| 72 |
+
|
| 73 |
+
# Compute the L1 cost between boxes
|
| 74 |
+
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
|
| 75 |
+
|
| 76 |
+
# Compute the giou cost betwen boxes
|
| 77 |
+
# import ipdb; ipdb.set_trace()
|
| 78 |
+
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
|
| 79 |
+
|
| 80 |
+
# Final cost matrix
|
| 81 |
+
C = self.cost_conf*cost_conf + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
|
| 82 |
+
C = C.cpu()
|
| 83 |
+
|
| 84 |
+
sizes = [len(v["boxes"]) for v in targets]
|
| 85 |
+
idx=0
|
| 86 |
+
indices = []
|
| 87 |
+
for i, c in enumerate(C.split(sizes, -1)):
|
| 88 |
+
indices.append(linear_sum_assignment(c[idx:idx+lens[i]]))
|
| 89 |
+
idx += lens[i]
|
| 90 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def forward(self, outputs, targets):
|
| 94 |
+
""" Performs the matching
|
| 95 |
+
Params:
|
| 96 |
+
outputs: This is a dict that contains at least these entries:
|
| 97 |
+
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
| 98 |
+
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
|
| 99 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
| 100 |
+
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
| 101 |
+
objects in the target) containing the class labels
|
| 102 |
+
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
|
| 103 |
+
Returns:
|
| 104 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
| 105 |
+
- index_i is the indices of the selected predictions (in order)
|
| 106 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
| 107 |
+
For each batch element, it holds:
|
| 108 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
| 109 |
+
"""
|
| 110 |
+
assert outputs['pred_confs'].shape[0]==len(targets)
|
| 111 |
+
bs, num_queries, _ = outputs["pred_confs"].shape
|
| 112 |
+
|
| 113 |
+
# We flatten to compute the cost matrices in a batch
|
| 114 |
+
out_conf = outputs['pred_confs'].flatten(0,1) # [batch_size * num_queries, 1]
|
| 115 |
+
out_bbox = outputs["pred_boxes"].flatten(0,1) # [batch_size * num_queries, 4]
|
| 116 |
+
out_kpts = outputs['pred_j2ds'][...,:22,:].flatten(2).flatten(0,1) / self.j2ds_norm_scale
|
| 117 |
+
|
| 118 |
+
# Also concat the target labels and boxes
|
| 119 |
+
tgt_bbox = torch.cat([v["boxes"] for v in targets])
|
| 120 |
+
tgt_kpts = torch.cat([v['j2ds'][:,:22,:].flatten(1) for v in targets]) / self.j2ds_norm_scale
|
| 121 |
+
tgt_kpts_mask = torch.cat([v['j2ds_mask'][:,:22,:].flatten(1) for v in targets])
|
| 122 |
+
tgt_kpts_vis_cnt = tgt_kpts_mask.sum(-1)
|
| 123 |
+
assert (torch.all(tgt_kpts_vis_cnt))
|
| 124 |
+
|
| 125 |
+
# Compute the confidence cost.
|
| 126 |
+
alpha = 0.25
|
| 127 |
+
gamma = 2.0
|
| 128 |
+
cost_conf = alpha * ((1 - out_conf) ** gamma) * (-(out_conf + 1e-8).log())
|
| 129 |
+
# cost_conf = -(out_conf+1e-8).log()
|
| 130 |
+
|
| 131 |
+
# Compute the L1 cost between boxes
|
| 132 |
+
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
|
| 133 |
+
|
| 134 |
+
# Compute the giou cost betwen boxes
|
| 135 |
+
# import ipdb; ipdb.set_trace()
|
| 136 |
+
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
|
| 137 |
+
|
| 138 |
+
# Compute the mean L1 cost between visible joints
|
| 139 |
+
all_dist = torch.abs(out_kpts[:,None,:] - tgt_kpts[None,:,:])
|
| 140 |
+
mean_dist = (all_dist * tgt_kpts_mask[None,:,:]).sum(-1) / tgt_kpts_vis_cnt[None,:]
|
| 141 |
+
cost_kpts = mean_dist
|
| 142 |
+
|
| 143 |
+
# Final cost matrix
|
| 144 |
+
C = self.cost_conf*cost_conf + self.cost_kpts*cost_kpts + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
|
| 145 |
+
C = C.view(bs, num_queries, -1).cpu()
|
| 146 |
+
|
| 147 |
+
sizes = [len(v["boxes"]) for v in targets]
|
| 148 |
+
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
|
| 149 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def build_matcher(args):
|
| 153 |
+
return HungarianMatcher(
|
| 154 |
+
cost_conf=args.set_cost_conf,
|
| 155 |
+
cost_bbox=args.set_cost_bbox,
|
| 156 |
+
cost_giou=args.set_cost_giou,
|
| 157 |
+
cost_kpts=args.set_cost_kpts,
|
| 158 |
+
j2ds_norm_scale=args.input_size
|
| 159 |
+
)
|
models/position_encoding.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 2 |
+
"""
|
| 3 |
+
Various positional encodings for the transformer.
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from utils.misc import NestedTensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def position_encoding_xy(pos_x, pos_y, embedding_dim, temperature = 20, scale = 2*math.pi):
|
| 13 |
+
assert embedding_dim % 2 == 0
|
| 14 |
+
assert pos_x.ndim == 1 and pos_y.ndim == 1
|
| 15 |
+
dim_t = torch.arange(embedding_dim // 2, dtype=torch.float32, device=pos_x.device)
|
| 16 |
+
dim_t = temperature ** (2 * (dim_t // 2) / (embedding_dim // 2))
|
| 17 |
+
x_embed = pos_x * scale
|
| 18 |
+
y_embed = pos_y * scale
|
| 19 |
+
pos_x = x_embed[:, None] / dim_t
|
| 20 |
+
pos_y = y_embed[:, None] / dim_t
|
| 21 |
+
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
|
| 22 |
+
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
|
| 23 |
+
pos = torch.cat([pos_y,pos_x], dim=1)
|
| 24 |
+
return pos
|
| 25 |
+
|
| 26 |
+
class PositionEmbeddingSine(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 29 |
+
used by the Attention is all you need paper, generalized to work on images.
|
| 30 |
+
"""
|
| 31 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.num_pos_feats = num_pos_feats
|
| 34 |
+
self.temperature = temperature
|
| 35 |
+
self.normalize = normalize
|
| 36 |
+
if scale is not None and normalize is False:
|
| 37 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 38 |
+
if scale is None:
|
| 39 |
+
scale = 2 * math.pi
|
| 40 |
+
self.scale = scale
|
| 41 |
+
|
| 42 |
+
def forward(self, tensor_list: NestedTensor):
|
| 43 |
+
x = tensor_list.tensors
|
| 44 |
+
mask = tensor_list.mask
|
| 45 |
+
assert mask is not None
|
| 46 |
+
not_mask = ~mask
|
| 47 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
| 48 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
| 49 |
+
if self.normalize:
|
| 50 |
+
eps = 1e-6
|
| 51 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 52 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 53 |
+
|
| 54 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 55 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 56 |
+
|
| 57 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 58 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 59 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 60 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 61 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 62 |
+
return pos
|
| 63 |
+
|
| 64 |
+
class PositionEmbeddingSineHW(nn.Module):
    """Sinusoidal 2-D position embedding with separate temperatures per axis.

    Same construction as :class:`PositionEmbeddingSine`, but the height and
    width axes use independent temperatures (``temperatureH`` / ``temperatureW``),
    as in DAB-/DINO-style detectors.
    """
    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats  # channels per axis; output has 2*num_pos_feats channels
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        """Return positional encodings of shape (B, 2*num_pos_feats, H, W)."""
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        # Cumulative count of valid (unpadded) pixels gives each pixel its coordinate.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Explicit floor division: `//` on float tensors is deprecated in PyTorch;
        # torch.div(..., rounding_mode='floor') yields the same exponent pattern.
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * torch.div(dim_tx, 2, rounding_mode='floor') / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx

        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * torch.div(dim_ty, 2, rounding_mode='floor') / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty

        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos
|
| 111 |
+
|
| 112 |
+
class PositionEmbeddingLearned(nn.Module):
    """Absolute 2-D position embedding, learned.

    A row table and a column table are learned independently and concatenated
    per pixel, giving an output of shape (B, 2*num_pos_feats, H, W).

    Parameters:
        num_pos_feats: channels per axis (output has twice this many).
        max_size: maximum supported height/width of the feature map
            (previously a hard-coded 50; kept as the default for
            backward compatibility).
    """
    def __init__(self, num_pos_feats=256, max_size=50):
        super().__init__()
        self.row_embed = nn.Embedding(max_size, num_pos_feats)
        self.col_embed = nn.Embedding(max_size, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both embedding tables with U(0, 1)."""
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return learned positional encodings of shape (B, 2*num_pos_feats, H, W)."""
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)  # (w, num_pos_feats)
        y_emb = self.row_embed(j)  # (h, num_pos_feats)
        # Broadcast to (h, w, 2*num_pos_feats), then to channel-first per batch item.
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def build_position_encoding(args):
    """Construct the positional-encoding module selected by ``args.position_embedding``.

    ``'v2'``/``'sine'`` builds a normalized :class:`PositionEmbeddingSineHW`
    with the configured per-axis temperatures; ``'v3'``/``'learned'`` builds a
    :class:`PositionEmbeddingLearned`. Any other value raises ``ValueError``.
    """
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding

    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSineHW(
            N_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True
        )

    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(N_steps)

    raise ValueError(f"not supported {kind}")
|
models/sat_model.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from math import tan,pi
|
| 7 |
+
from typing import Dict
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torchvision.transforms import Resize
|
| 12 |
+
import numpy as np
|
| 13 |
+
import time
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
| 17 |
+
accuracy, get_world_size, interpolate,
|
| 18 |
+
is_dist_avail_and_initialized, inverse_sigmoid)
|
| 19 |
+
|
| 20 |
+
from utils.transforms import rot6d_to_axis_angle, img2patch_flat, img2patch, to_zorder
|
| 21 |
+
from utils.map import build_z_map
|
| 22 |
+
from utils import constants
|
| 23 |
+
from configs.paths import smpl_mean_path
|
| 24 |
+
|
| 25 |
+
from models.encoders import build_encoder
|
| 26 |
+
from .matcher import build_matcher
|
| 27 |
+
from .decoder import build_decoder
|
| 28 |
+
from .position_encoding import position_encoding_xy
|
| 29 |
+
from .criterion import SetCriterion
|
| 30 |
+
from .dn_components import prepare_for_cdn, dn_post_process
|
| 31 |
+
import copy
|
| 32 |
+
|
| 33 |
+
from configs.paths import smpl_model_path
|
| 34 |
+
from models.human_models import SMPL_Layer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _get_clones(module, N):
|
| 38 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 39 |
+
|
| 40 |
+
class Model(nn.Module):
|
| 41 |
+
""" One-stage Multi-person Human Mesh Estimation via Scale-adaptive Tokens """
|
| 42 |
+
def __init__(self, encoder, decoder,
|
| 43 |
+
num_queries,
|
| 44 |
+
input_size,
|
| 45 |
+
sat_cfg = {'use_sat': False},
|
| 46 |
+
dn_cfg = {'use_dn': False},
|
| 47 |
+
train_pos_embed = True,
|
| 48 |
+
aux_loss=True,
|
| 49 |
+
iter_update=True,
|
| 50 |
+
query_dim=4,
|
| 51 |
+
bbox_embed_diff_each_layer=True,
|
| 52 |
+
random_refpoints_xy=False,
|
| 53 |
+
num_poses=24,
|
| 54 |
+
dim_shape=10,
|
| 55 |
+
FOV=pi/3
|
| 56 |
+
):
|
| 57 |
+
""" Initializes the model.
|
| 58 |
+
Parameters:
|
| 59 |
+
encoder: torch module of the encoder to be used. See ./encoders.
|
| 60 |
+
decoder: torch module of the decoder architecture. See decoder.py
|
| 61 |
+
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
| 62 |
+
DETR can detect in a single image.
|
| 63 |
+
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 64 |
+
iter_update: iterative update of boxes
|
| 65 |
+
query_dim: query dimension. 2 for point and 4 for box.
|
| 66 |
+
bbox_embed_diff_each_layer: dont share weights of prediction heads. Default for False. (shared weights.)
|
| 67 |
+
random_refpoints_xy: random init the x,y of anchor boxes and freeze them. (It sometimes helps to improve the performance)
|
| 68 |
+
"""
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
# ========== Start of common settings =============
|
| 72 |
+
self.input_size = input_size
|
| 73 |
+
hidden_dim = decoder.d_model
|
| 74 |
+
num_dec_layers = decoder.dec_layers
|
| 75 |
+
self.hidden_dim = hidden_dim
|
| 76 |
+
# camera model
|
| 77 |
+
self.focal = input_size/(2*tan(FOV/2))
|
| 78 |
+
self.FOV = FOV
|
| 79 |
+
cam_intrinsics = torch.tensor([[self.focal,0.,self.input_size/2],
|
| 80 |
+
[0.,self.focal,self.input_size/2],
|
| 81 |
+
[0.,0.,1.]])
|
| 82 |
+
self.register_buffer('cam_intrinsics', cam_intrinsics)
|
| 83 |
+
# human model
|
| 84 |
+
self.num_poses = num_poses
|
| 85 |
+
self.dim_shape = dim_shape
|
| 86 |
+
self.human_model = SMPL_Layer(model_path = smpl_model_path, with_genders = False)
|
| 87 |
+
# init params (following multi-hmr)
|
| 88 |
+
smpl_mean_params = np.load(smpl_mean_path, allow_pickle = True)
|
| 89 |
+
self.register_buffer('mean_pose', torch.from_numpy(smpl_mean_params['pose']))
|
| 90 |
+
self.register_buffer('mean_shape', torch.from_numpy(smpl_mean_params['shape']))
|
| 91 |
+
# ========== End of common settings =============
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ========== Start of SAT-encoder settings =============
|
| 95 |
+
self.encoder = encoder
|
| 96 |
+
|
| 97 |
+
self.patch_size = encoder.patch_size
|
| 98 |
+
assert self.patch_size == 14
|
| 99 |
+
|
| 100 |
+
self.use_sat = sat_cfg['use_sat']
|
| 101 |
+
self.sat_cfg = sat_cfg
|
| 102 |
+
|
| 103 |
+
if self.use_sat:
|
| 104 |
+
assert sat_cfg['num_lvls'] >= 2
|
| 105 |
+
assert self.input_size % (self.patch_size<<2) == 0
|
| 106 |
+
|
| 107 |
+
self.feature_size = []
|
| 108 |
+
for lvl in range(sat_cfg['num_lvls']):
|
| 109 |
+
patch_size = self.patch_size<<lvl
|
| 110 |
+
self.feature_size.append(self.input_size / patch_size)
|
| 111 |
+
|
| 112 |
+
# build z_order curve
|
| 113 |
+
z_depth = math.ceil(math.log2(self.feature_size[1]))
|
| 114 |
+
z_map, ys, xs = build_z_map(z_depth)
|
| 115 |
+
self.register_buffer('z_order_map', z_map)
|
| 116 |
+
self.register_buffer('y_coords', ys)
|
| 117 |
+
self.register_buffer('x_coords', xs)
|
| 118 |
+
|
| 119 |
+
self.enc_inter_norm = copy.deepcopy(encoder.norm)
|
| 120 |
+
self.scale_head = MLP(encoder.embed_dim, encoder.embed_dim, 2, 4)
|
| 121 |
+
self.encoder_patch_proj = _get_clones(encoder.patch_embed.proj, 2)
|
| 122 |
+
self.encoder_patch_norm = _get_clones(encoder.patch_embed.norm, 2)
|
| 123 |
+
|
| 124 |
+
if sat_cfg['lvl_embed']:
|
| 125 |
+
# same as level_embed in Deformable-DETR
|
| 126 |
+
self.level_embed = nn.Parameter(torch.Tensor(sat_cfg['num_lvls'],hidden_dim))
|
| 127 |
+
nn.init.normal_(self.level_embed)
|
| 128 |
+
else:
|
| 129 |
+
assert self.input_size % self.patch_size == 0
|
| 130 |
+
self.feature_size = [self.input_size // self.patch_size]
|
| 131 |
+
self.encoder_patch_proj = copy.deepcopy(encoder.patch_embed.proj)
|
| 132 |
+
self.encoder_patch_norm = copy.deepcopy(encoder.patch_embed.norm)
|
| 133 |
+
|
| 134 |
+
# cls_token and register tokens
|
| 135 |
+
encoder_cr_token = self.encoder.cls_token.view(1,-1) + self.encoder.pos_embed.float()[:,0].view(1,-1)
|
| 136 |
+
if self.encoder.register_tokens is not None:
|
| 137 |
+
encoder_cr_token = torch.cat([encoder_cr_token, self.encoder.register_tokens.view(self.encoder.num_register_tokens,-1)], dim=0)
|
| 138 |
+
self.encoder_cr_token = nn.Parameter(encoder_cr_token)
|
| 139 |
+
|
| 140 |
+
self.encoder_pos_embeds = nn.Parameter(self.encoder.interpolate_pos_encoding3(self.feature_size[0]).detach())
|
| 141 |
+
if not train_pos_embed:
|
| 142 |
+
self.encoder_pos_embeds.requires_grad = False
|
| 143 |
+
|
| 144 |
+
self.preprocessed_pos_lvl1 = None
|
| 145 |
+
|
| 146 |
+
# delete unwanted params
|
| 147 |
+
del(self.encoder.mask_token)
|
| 148 |
+
del(self.encoder.pos_embed)
|
| 149 |
+
del(self.encoder.patch_embed)
|
| 150 |
+
del(self.encoder.cls_token)
|
| 151 |
+
del(self.encoder.register_tokens)
|
| 152 |
+
# ========== End of SAT-encoder settings =============
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ========== Start of decoder settings =============
|
| 157 |
+
self.num_queries = num_queries
|
| 158 |
+
self.decoder = decoder
|
| 159 |
+
|
| 160 |
+
# embed_dim between encoder and decoder can be different
|
| 161 |
+
self.feature_proj = nn.Linear(encoder.embed_dim, hidden_dim)
|
| 162 |
+
|
| 163 |
+
# bbox
|
| 164 |
+
self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
|
| 165 |
+
if bbox_embed_diff_each_layer:
|
| 166 |
+
self.bbox_embed = nn.ModuleList([MLP(hidden_dim, hidden_dim, 4, 3) for i in range(num_dec_layers)])
|
| 167 |
+
else:
|
| 168 |
+
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
| 169 |
+
# poses (use 6D rotation)
|
| 170 |
+
self.pose_head = MLP(hidden_dim, hidden_dim, num_poses*6, 6)
|
| 171 |
+
# shape
|
| 172 |
+
self.shape_head = MLP(hidden_dim, hidden_dim, dim_shape, 5)
|
| 173 |
+
# cam_trans
|
| 174 |
+
self.cam_head = MLP(hidden_dim, hidden_dim//2, 3, 3)
|
| 175 |
+
# confidence score
|
| 176 |
+
self.conf_head = nn.Linear(hidden_dim, 1)
|
| 177 |
+
# init prior_prob setting for focal loss
|
| 178 |
+
prior_prob = 0.01
|
| 179 |
+
bias_value = -math.log((1 - prior_prob) / prior_prob)
|
| 180 |
+
self.conf_head.bias.data = torch.ones(1) * bias_value
|
| 181 |
+
|
| 182 |
+
# for iter update
|
| 183 |
+
self.pose_head = _get_clones(self.pose_head, num_dec_layers)
|
| 184 |
+
self.shape_head = _get_clones(self.shape_head, num_dec_layers)
|
| 185 |
+
|
| 186 |
+
# setting query dim (bboxes as queries)
|
| 187 |
+
self.query_dim = query_dim
|
| 188 |
+
assert query_dim == 4
|
| 189 |
+
self.refpoint_embed = nn.Embedding(num_queries, query_dim)
|
| 190 |
+
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
|
| 191 |
+
|
| 192 |
+
self.random_refpoints_xy = random_refpoints_xy
|
| 193 |
+
if random_refpoints_xy:
|
| 194 |
+
# import ipdb; ipdb.set_trace()
|
| 195 |
+
self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
|
| 196 |
+
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
|
| 197 |
+
self.refpoint_embed.weight.data[:, :2].requires_grad = False
|
| 198 |
+
|
| 199 |
+
self.aux_loss = aux_loss
|
| 200 |
+
self.iter_update = iter_update
|
| 201 |
+
assert iter_update
|
| 202 |
+
if self.iter_update:
|
| 203 |
+
self.decoder.decoder.bbox_embed = self.bbox_embed
|
| 204 |
+
|
| 205 |
+
assert bbox_embed_diff_each_layer
|
| 206 |
+
if bbox_embed_diff_each_layer:
|
| 207 |
+
for bbox_embed in self.bbox_embed:
|
| 208 |
+
nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
|
| 209 |
+
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
|
| 210 |
+
else:
|
| 211 |
+
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
|
| 212 |
+
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
|
| 213 |
+
# ========== End of decoder settings =============
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# for dn training
|
| 217 |
+
self.use_dn = dn_cfg['use_dn']
|
| 218 |
+
self.dn_cfg = dn_cfg
|
| 219 |
+
if self.use_dn:
|
| 220 |
+
assert dn_cfg['dn_number'] > 0
|
| 221 |
+
if dn_cfg['tgt_embed_type'] == 'labels':
|
| 222 |
+
self.dn_enc = nn.Embedding(dn_cfg['dn_labelbook_size'], hidden_dim)
|
| 223 |
+
elif dn_cfg['tgt_embed_type'] == 'params':
|
| 224 |
+
self.dn_enc = nn.Linear(num_poses*3 + dim_shape, hidden_dim)
|
| 225 |
+
else:
|
| 226 |
+
raise NotImplementedError
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def lvl_pooling(self, tokens):
|
| 230 |
+
assert len(tokens)%4 == 0
|
| 231 |
+
C = tokens.shape[-1]
|
| 232 |
+
return torch.max(tokens.view(-1, 4, C), dim=1)[0]
|
| 233 |
+
|
| 234 |
+
    def get_scale_map(self, x_list):
        """Run the early encoder layers and predict a per-token (confidence, scale) map.

        x_list: per-image token sequences, each of shape (1, 1+R+L, C) where the
            first 1+R tokens are the cls/register tokens and the rest are patch
            tokens.
        Returns:
            scale_map: (sum(L), 2) sigmoid outputs of the scale head over all
                patch tokens of the batch (column semantics defined by the
                downstream thresholds: [:,0] confidence, [:,1] scale).
            cr_token_list: updated cls/register tokens per image.
            x_tokens: the (pre-norm, pre-head) updated patch tokens, concatenated
                over the batch.
        """
        # Advance tokens through encoder layers up to 'get_map_layer', either via
        # dedicated additional blocks or the shared encoder blocks.
        if self.sat_cfg['use_additional_blocks']:
            x_list = self.encoder.forward_additional_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
        else:
            x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)

        # Split off cls/register tokens; pool all patch tokens into one flat tensor.
        cr_token_list = [x[:, :1 + self.encoder.num_register_tokens, :].squeeze(0) for x in x_list]
        x_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
        scale_map = self.scale_head(self.enc_inter_norm(x_tokens)).sigmoid()
        return scale_map, cr_token_list, x_tokens
|
| 244 |
+
|
| 245 |
+
def pad_mask(self, mask):
|
| 246 |
+
mask = mask.reshape(-1,4)
|
| 247 |
+
mask[torch.any(mask, dim=1)] = True
|
| 248 |
+
return mask.flatten()
|
| 249 |
+
|
| 250 |
+
    def forward_encoder(self, samples, targets, use_gt = False):
        """Tokenize the input images and run the (optionally scale-adaptive) encoder.

        samples: list of per-image tensors (C, H, W) — assumed already resized so
            H and W are multiples of the relevant patch size; TODO confirm against
            the dataloader.
        targets: per-image target dicts; only ``tgt['scale_map']`` is read here,
            and only when ``use_gt`` is True.
        use_gt: if True, replace the predicted scale map with the ground-truth one
            (teacher-forcing for SAT training).

        Returns:
            final_features: (sum(L), hidden_dim) projected encoder tokens.
            pos_embeds: matching positional (+ optional level) embeddings.
            token_lens: tokens per image.
            scale_map_dict: visualization info for the predicted scale map
                (None when SAT is disabled).
            sat_dict: per-image token positions/levels/lengths.
        """
        B = len(samples)
        C = self.encoder.embed_dim
        # Shared cls/register tokens, one copy per image.
        cr_token_list = [self.encoder_cr_token]*len(samples)

        if not self.use_sat:
            # ---- Plain single-level tokenization (no scale-adaptive tokens) ----
            # img2token
            lvl0_feature_hw = [(img.shape[1]//self.patch_size, img.shape[2]//self.patch_size) for img in samples]
            lvl0_token_lens = [h*w for (h,w) in lvl0_feature_hw]
            lvl0_img_patches = torch.cat([img2patch_flat(img, patch_size = self.patch_size)\
                                          for img in samples], dim=0)
            lvl0_tokens = self.encoder_patch_norm(self.encoder_patch_proj(lvl0_img_patches).flatten(1))

            # token position information (row-major grid coordinates per image)
            full_grids = torch.meshgrid(torch.arange(self.feature_size[0]), torch.arange(self.feature_size[0]), indexing='ij')
            lvl0_pos_y = torch.cat([full_grids[0][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)
            lvl0_pos_x = torch.cat([full_grids[1][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)

            # pos_embed
            full_pos_embed = self.encoder_pos_embeds
            lvl0_pos_embed = torch.cat([full_pos_embed[:h,:w].flatten(0,1)\
                                        for (h,w) in lvl0_feature_hw], dim=0)
            lvl0_tokens = lvl0_tokens + lvl0_pos_embed

            # convert to list for DINOv2 input
            x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
                      for (cr, lvl0) \
                      in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]

            # Normalized token-center coordinates in [0, 1].
            lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
            lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
            # NOTE(review): pos_x_list is filled from the *y* norms and pos_y_list
            # from the *x* norms here, the opposite of the SAT branch below —
            # looks like a swapped x/y bug; confirm against position_encoding_xy
            # and any checkpoint trained with this code before changing.
            pos_x_list = list(lvl0_pos_y_norm.split(lvl0_token_lens))
            pos_y_list = list(lvl0_pos_x_norm.split(lvl0_token_lens))
            scale_map_dict = None
            # also create lvl_list for patch visualization
            lvl_list = [torch.zeros_like(pos,dtype=int) for pos in pos_x_list]

        else:
            # ---- Scale-adaptive tokens: start at level 1 (2x patch size) ----
            lvl1_feature_hw = [(img.shape[1]//(2*self.patch_size), img.shape[2]//(2*self.patch_size)) for img in samples]
            lvl1_token_lens = [h*w for (h,w) in lvl1_feature_hw]

            lvl1_img_patches_28, lvl1_zorders = [], []
            lvl1_pos_y, lvl1_pos_x = [], []
            lvl1_bids = []

            # Extract 28x28 patches per image, reordered along the z-order curve
            # so sibling groups of 4 are contiguous.
            for i, img in enumerate(samples):
                z_patches, z_order, pos_y, pos_x = to_zorder(img2patch(img, patch_size = 2*self.patch_size),
                                                             z_order_map = self.z_order_map,
                                                             y_coords = self.y_coords,
                                                             x_coords = self.x_coords)

                lvl1_img_patches_28.append(z_patches)

                lvl1_zorders.append(z_order)
                lvl1_pos_y.append(pos_y)
                lvl1_pos_x.append(pos_x)
                # batch id per token, needed to avoid cross-image sibling groups
                lvl1_bids.append(torch.full_like(pos_y, i, dtype=torch.int64))

            lvl1_img_patches_28 = torch.cat(lvl1_img_patches_28, dim=0)
            lvl1_zorders = torch.cat(lvl1_zorders, dim=0)
            lvl1_pos_y = torch.cat(lvl1_pos_y, dim=0)
            lvl1_pos_x = torch.cat(lvl1_pos_x, dim=0)
            lvl1_bids = torch.cat(lvl1_bids, dim=0)

            # (L1, 3, 28, 28)
            assert len(lvl1_img_patches_28) == sum(lvl1_token_lens)
            # Downsample 28x28 patches to the encoder's native 14x14 before projecting.
            lvl1_img_patches = F.interpolate(lvl1_img_patches_28, size = (14,14), mode='bilinear', align_corners=False)
            # (L1, 3, 14, 14)
            lvl1_tokens = self.encoder_patch_norm[1](self.encoder_patch_proj[1](lvl1_img_patches).flatten(1))
            # (L1, C)

            assert len(lvl1_pos_y) == len(lvl1_tokens)
            # Level-1 positional embeddings: precomputed cache at eval time,
            # recomputed by bicubic interpolation during training.
            full_pos_embed = self.preprocessed_pos_lvl1 if not self.training\
                             else F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
                                                mode="bicubic",
                                                antialias=self.encoder.interpolate_antialias,
                                                size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)
            lvl1_pos_embed = torch.cat([full_pos_embed[ys,xs]\
                                        for (ys,xs) in zip(lvl1_pos_y.split(lvl1_token_lens), lvl1_pos_x.split(lvl1_token_lens))], dim=0)
            lvl1_tokens = lvl1_tokens + lvl1_pos_embed

            # get scale map (flattened)
            x_list = [torch.cat([cr, lvl1],dim=0).unsqueeze(0)\
                      for (cr, lvl1) \
                      in zip(cr_token_list, lvl1_tokens.split(lvl1_token_lens))]
            scale_map, updated_cr_list, updated_lvl1 = self.get_scale_map(x_list)
            # for visualization
            scale_map_dict = {'scale_map': scale_map, 'lens': lvl1_token_lens, 'hw': lvl1_feature_hw,
                              'pos_y': lvl1_pos_y, 'pos_x': lvl1_pos_x}

            # get sat masks: a token is subdivided (SAT) when it is confident
            # enough AND its predicted scale is fine enough
            conf_thresh = self.sat_cfg['conf_thresh']
            scale_thresh = self.sat_cfg['scale_thresh']
            if use_gt:
                scale_map = torch.cat([tgt['scale_map'].view(-1,2) for tgt in targets], dim=0)

            lvl1_valid_mask = scale_map[:,0] > conf_thresh
            lvl1_sat_mask = lvl1_valid_mask & (scale_map[:,1] < scale_thresh)

            # prepare sat tokens (lvl0): each selected lvl1 token splits into 4
            lvl0_token_lens = [msk.sum().item()<<2 for msk in lvl1_sat_mask.split(lvl1_token_lens)]
            lvl1_sat_patches_28 = lvl1_img_patches_28[lvl1_sat_mask] # (L0//4, 3, 28, 28)
            lvl0_tokens = self.encoder_patch_norm[0](self.encoder_patch_proj[0](lvl1_sat_patches_28).permute(0, 2, 3, 1).flatten(0,2))

            assert len(lvl0_tokens) == sum(lvl0_token_lens)
            # lvl0 positions: each lvl1 coordinate doubles, then the 4 children
            # get offsets (0,0), (0,1), (1,0), (1,1)
            lvl0_pos_y, lvl0_pos_x = lvl1_pos_y[lvl1_sat_mask], lvl1_pos_x[lvl1_sat_mask]
            lvl0_pos_y = (lvl0_pos_y<<1)[:,None].repeat(1,4).flatten()
            lvl0_pos_x = (lvl0_pos_x<<1)[:,None].repeat(1,4).flatten()
            lvl0_pos_y[2::4] += 1
            lvl0_pos_y[3::4] += 1
            lvl0_pos_x[1::2] += 1
            assert len(lvl0_pos_x) == len(lvl0_tokens)

            # lvl0 pos_embed
            full_pos_embed = self.encoder_pos_embeds
            lvl0_pos_embed = torch.cat([full_pos_embed[ys,xs]\
                                        for (ys,xs) in zip(lvl0_pos_y.split(lvl0_token_lens), lvl0_pos_x.split(lvl0_token_lens))], dim=0)
            lvl0_tokens = lvl0_tokens + lvl0_pos_embed

            # update tokens: run the fresh lvl0 tokens through the same early
            # layers the lvl1 tokens already passed in get_scale_map
            x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
                      for (cr, lvl0) \
                      in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]
            x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
            lvl0_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
            assert len(lvl0_pos_x) == len(lvl0_tokens)
            # also update lvl1 and crs
            lvl1_tokens = updated_lvl1
            cr_token_list = updated_cr_list

            if self.sat_cfg['num_lvls'] == 2:
                # drop corresponding lvl1 tokens
                lvl1_keep = ~lvl1_sat_mask
                lvl1_token_lens = [msk.sum().item() for msk in lvl1_keep.split(lvl1_token_lens)]
                lvl1_tokens = lvl1_tokens[lvl1_keep]
                lvl1_pos_y = lvl1_pos_y[lvl1_keep]
                lvl1_pos_x = lvl1_pos_x[lvl1_keep]

                # normalize positions
                lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
                lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
                lvl1_pos_y_norm = (lvl1_pos_y.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]
                lvl1_pos_x_norm = (lvl1_pos_x.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]

                # merge all
                x_list = [torch.cat([cr, lvl0, lvl1]).unsqueeze(0) \
                          for cr, lvl0, lvl1 \
                          in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens), lvl1_tokens.split(lvl1_token_lens))]
                pos_y_list = [torch.cat([lvl0, lvl1]) \
                              for lvl0, lvl1 \
                              in zip(lvl0_pos_y_norm.split(lvl0_token_lens), lvl1_pos_y_norm.split(lvl1_token_lens))]
                pos_x_list = [torch.cat([lvl0, lvl1]) \
                              for lvl0, lvl1 \
                              in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]
                lvl_list = [torch.cat([torch.zeros_like(lvl0, dtype=int), torch.ones_like(lvl1, dtype=int)]) \
                            for lvl0, lvl1 \
                            in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]

            else:
                # prune lvl1 correspond to lvl0
                # pad_mask keeps whole sibling groups valid; unconfident groups
                # are pooled up to coarser levels below
                lvl1_valid_mask = self.pad_mask(lvl1_valid_mask)
                lvl1_keep = lvl1_valid_mask & (~lvl1_sat_mask)
                lvl1_to_lvl2 = ~lvl1_valid_mask

                token_lvls = [lvl0_tokens, lvl1_tokens]
                token_lens_lvls = [lvl0_token_lens, lvl1_token_lens]
                pos_y_lvls = [lvl0_pos_y, lvl1_pos_y]
                pos_x_lvls = [lvl0_pos_x, lvl1_pos_x]

                to_next_lvl = lvl1_to_lvl2
                keep = lvl1_keep
                lvl_zorders = lvl1_zorders
                lvl_bids = lvl1_bids
                # Sentinel padding so the z-order sibling check below can index
                # past the end without bounds errors.
                pad_vals = torch.full((3,), -1, dtype=lvl_zorders.dtype, device=lvl_zorders.device)
                for lvl in range(self.sat_cfg['num_lvls']-2):
                    if to_next_lvl.sum() == 0:
                        break
                    # pool each group of 4 siblings into one coarser token
                    next_tokens = self.lvl_pooling(token_lvls[-1][to_next_lvl])
                    next_pos_y = pos_y_lvls[-1][to_next_lvl][::4]>>1
                    next_pos_x = pos_x_lvls[-1][to_next_lvl][::4]>>1
                    next_lens = [msk.sum().item()//4 for msk in to_next_lvl.split(token_lens_lvls[-1])]

                    # retain only the kept tokens at the current finest level
                    token_lvls[-1] = token_lvls[-1][keep]
                    pos_y_lvls[-1] = pos_y_lvls[-1][keep]
                    pos_x_lvls[-1] = pos_x_lvls[-1][keep]
                    token_lens_lvls[-1] = [msk.sum().item() for msk in keep.split(token_lens_lvls[-1])]

                    token_lvls.append(next_tokens)
                    token_lens_lvls.append(next_lens)
                    pos_y_lvls.append(next_pos_y)
                    pos_x_lvls.append(next_pos_x)

                    if lvl < self.sat_cfg['num_lvls']-3:
                        # advance z-order bookkeeping one level up
                        lvl_zorders = lvl_zorders[to_next_lvl][::4]>>2
                        lvl_bids = lvl_bids[to_next_lvl][::4]

                        # a group of 4 can be pooled further only if its z-order
                        # codes are consecutive AND all 4 come from the same image
                        z_starts_idx = torch.where((lvl_zorders&3)==0)[0]
                        padded_z = torch.cat([lvl_zorders, pad_vals])
                        padded_bids = torch.cat([lvl_bids, pad_vals])
                        valids = (padded_z[z_starts_idx] + 3 == padded_z[z_starts_idx + 3]) & (padded_bids[z_starts_idx] == padded_bids[z_starts_idx + 3])
                        valid_starts = z_starts_idx[valids]

                        to_next_lvl = torch.zeros_like(lvl_zorders, dtype=bool)
                        to_next_lvl[valid_starts] = True
                        to_next_lvl[valid_starts+1] = True
                        to_next_lvl[valid_starts+2] = True
                        to_next_lvl[valid_starts+3] = True

                        keep = ~to_next_lvl

                norm_pos_y_lvls = [(pos_y.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_y in enumerate(pos_y_lvls)]
                norm_pos_x_lvls = [(pos_x.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_x in enumerate(pos_x_lvls)]

                # merge all levels per image: [cls/reg | lvl0 | lvl1 | ...]
                x_list = [torch.cat([cr, *lvls]).unsqueeze(0) \
                          for cr, *lvls \
                          in zip(cr_token_list, *[tokens.split(lens) for (tokens, lens) in zip(token_lvls, token_lens_lvls)])]
                pos_y_list = [torch.cat([*lvls]) \
                              for lvls \
                              in zip(*[pos_y.split(lens) for (pos_y, lens) in zip(norm_pos_y_lvls, token_lens_lvls)])]
                pos_x_list = [torch.cat([*lvls]) \
                              for lvls \
                              in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]
                lvl_list = [torch.cat([torch.full_like(lvl, i, dtype=torch.int64) for i, lvl in enumerate(lvls)]) \
                            for lvls \
                            in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]

        # Run the remaining encoder layers (all layers when SAT is off).
        start = self.sat_cfg['get_map_layer'] if self.use_sat else 0
        _, final_feature_list = self.encoder.forward_specific_layers_list(x_list, start = start, norm=True)

        # proj
        token_lens = [feature.shape[1] for feature in final_feature_list]
        final_features = self.feature_proj(torch.cat(final_feature_list,dim=1).squeeze(0)) # (sum(L), C)
        assert tuple(final_features.shape) == (sum(token_lens), self.hidden_dim)
        # positional encoding
        pos_embeds = position_encoding_xy(torch.cat(pos_x_list,dim=0), torch.cat(pos_y_list,dim=0), embedding_dim=self.hidden_dim)
        if self.use_sat and self.sat_cfg['lvl_embed']:
            lvl_embeds = self.level_embed[torch.cat(lvl_list,dim=0)]
            pos_embeds = pos_embeds + lvl_embeds

        sat_dict = {'pos_y': pos_y_list, 'pos_x': pos_x_list, 'lvl': lvl_list,
                    'lens': token_lens}

        return final_features, pos_embeds, token_lens, scale_map_dict, sat_dict
|
| 510 |
+
|
| 511 |
+
def process_smpl(self, poses, shapes, cam_xys, cam_intrinsics, detach_j3ds = False):
    """Decode per-query SMPL parameters into camera-space meshes/joints and 2D projections.

    Args:
        poses: (bs, num_queries, num_poses*3) axis-angle SMPL poses.
        shapes: (bs, num_queries, 10) SMPL shape coefficients (betas).
        cam_xys: (bs, num_queries, 3) raw camera parameters; the last channel
            is squashed into a scale, the first two become x/y translation.
        cam_intrinsics: camera intrinsic matrices; assumed (bs, 1, 3, 3) so
            they broadcast over queries — TODO confirm against caller.
        detach_j3ds: if True, stop gradients through the 3D joints when
            projecting to 2D (only the translation stays differentiable).

    Returns:
        verts_cam: (bs, num_queries, num_verts, 3) vertices in camera space.
        j3ds_cam: (bs, num_queries, num_joints, 3) joints in camera space.
        j2ds_img: (bs, num_queries, num_joints, 2) projected 2D joints.
        depths: (bs, num_queries, 2) root depth and focal-normalized depth.
        transl: (bs, num_queries, 3) per-query camera translation.
    """
    bs, num_queries, _ = poses.shape # should be (bs,n_q,num_poses*3)

    # flatten batch and query dims so the SMPL layer sees one flat batch
    poses = poses.flatten(0,1) # (bs*n_q,24*3)
    shapes = shapes.flatten(0,1) # (bs*n_q,10)
    verts, joints = self.human_model(poses=poses,
                                     betas=shapes)
    num_verts = verts.shape[1]
    num_joints = joints.shape[1]
    verts = verts.reshape(bs,num_queries,num_verts,3)
    joints = joints.reshape(bs,num_queries,num_joints,3)

    # apply cam_trans and projection:
    # sigmoid keeps the scale in (0, 2); +1e-6 guards the divisions below
    scale = 2*cam_xys[:,:,2:].sigmoid() + 1e-6
    t_xy = cam_xys[:,:,:2]/scale
    # depth from the weak-perspective-style scale and the fixed focal length
    t_z = (2*self.focal)/(scale*self.input_size) # (bs,num_queries,1)
    transl = torch.cat([t_xy,t_z],dim=2)[:,:,None,:] # (bs,nq,1,3)

    verts_cam = verts + transl # only for visualization and evaluation
    j3ds_cam = joints + transl

    # perspective projection; optionally detach the 3D joints so the 2D
    # reprojection loss does not back-propagate into the pose/shape heads
    if detach_j3ds:
        j2ds_homo = torch.matmul(joints.detach() + transl, cam_intrinsics.transpose(2,3))
    else:
        j2ds_homo = torch.matmul(j3ds_cam, cam_intrinsics.transpose(2,3))
    # homogeneous divide; +1e-6 avoids division by zero at the camera plane
    j2ds_img = (j2ds_homo[..., :2] / (j2ds_homo[..., 2, None] + 1e-6)).reshape(bs,num_queries,num_joints,2)

    # root-joint (index 0) depth, raw and normalized by the focal length
    depths = j3ds_cam[:,:,0,2:] # (bs, n_q, 1)
    depths = torch.cat([depths, depths/self.focal], dim=-1) # (bs, n_q, 2)

    return verts_cam, j3ds_cam, j2ds_img, depths, transl.flatten(2)
|
| 545 |
+
def forward(self, samples: NestedTensor, targets, sat_use_gt = False, detach_j3ds = False):
    """Full detection + human-mesh-recovery forward pass.

    Args:
        samples: NestedTensor (or list / plain tensor) of batched images,
            shape [batch_size x 3 x H x W]; samples.mask marks padded pixels.
        targets: list of per-image target dicts; 'img_size' is read here,
            the rest is consumed by the encoder and the denoising branch.
        sat_use_gt: forwarded to the encoder — use ground truth instead of
            predictions for scale-adaptive token (SAT) selection.
        detach_j3ds: stop gradients through 3D joints when projecting to 2D
            (see process_smpl).

    Returns:
        dict with last-decoder-layer predictions ('pred_poses', 'pred_betas',
        'pred_boxes', 'pred_confs', 'pred_j3ds', 'pred_j2ds', 'pred_verts',
        'pred_intrinsics', 'pred_depths', 'pred_transl'), plus optionally
        'aux_outputs' (training with aux loss), 'enc_outputs' and 'sat'
        (when SAT is enabled) and 'dn_meta' (training with denoising).
    """

    assert isinstance(samples, (list, torch.Tensor))

    # cache the interpolated level-1 positional embedding for inference;
    # invalidate during training since the embeddings may be updated
    if self.training:
        self.preprocessed_pos_lvl1 = None

    elif self.preprocessed_pos_lvl1 is None and self.use_sat:
        self.preprocessed_pos_lvl1 = F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
                                        mode="bicubic",
                                        antialias=self.encoder.interpolate_antialias,
                                        size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)


    bs = len(targets)

    # get cam_intrinsics: scale the principal point by each image's valid
    # (unpadded) fraction of the square input
    img_size = torch.stack([t['img_size'].flip(0) for t in targets])
    valid_ratio = img_size/self.input_size

    cam_intrinsics = self.cam_intrinsics.repeat(bs, 1, 1, 1)
    cam_intrinsics[...,:2,2] = cam_intrinsics[...,:2,2] * valid_ratio[:, None, :]


    final_features, pos_embeds, token_lens, scale_map_dict, sat_dict\
        = self.forward_encoder(samples, targets, use_gt = sat_use_gt)

    # default dab-detr pipeline: learned content queries + anchor boxes
    embedweight = (self.refpoint_embed.weight).unsqueeze(0).repeat(bs,1,1)
    tgt = (self.tgt_embed.weight).unsqueeze(0).repeat(bs,1,1)

    # contrastive denoising (DINO-style): prepend noised GT queries and
    # build the attention mask that isolates them from the matching queries
    if self.training and self.use_dn:
        input_query_tgt, input_query_bbox, attn_mask, dn_meta =\
            prepare_for_cdn(targets = targets, dn_cfg = self.dn_cfg,
                            num_queries = self.num_queries, hidden_dim = self.hidden_dim, dn_enc = self.dn_enc)
        tgt = torch.cat([input_query_tgt, tgt], dim=1)
        embedweight = torch.cat([input_query_bbox, embedweight], dim=1)
    else:
        attn_mask = None

    tgt_lens = [tgt.shape[1]]*bs

    hs, reference = self.decoder(memory=final_features, memory_lens=token_lens,
                                 tgt=tgt.flatten(0,1), tgt_lens=tgt_lens,
                                 refpoint_embed=embedweight.flatten(0,1),
                                 pos_embed=pos_embeds,
                                 self_attn_mask = attn_mask)

    # per-layer box refinement: predict deltas in logit space on top of the
    # (inverse-sigmoided) reference points, then squash back to [0, 1]
    reference_before_sigmoid = inverse_sigmoid(reference)
    outputs_coords = []
    for lvl in range(hs.shape[0]):
        tmp = self.bbox_embed[lvl](hs[lvl])
        tmp[..., :self.query_dim] += reference_before_sigmoid[lvl]
        outputs_coord = tmp.sigmoid()
        outputs_coords.append(outputs_coord)
    pred_boxes = torch.stack(outputs_coords)

    outputs_poses = []
    outputs_shapes = []
    outputs_confs = []
    outputs_j3ds = []
    outputs_j2ds = []
    outputs_depths = []

    # shape of hs: (lvl, bs, num_queries, dim)
    # SMPL params are predicted iteratively: each decoder layer adds a
    # residual on top of the mean pose/shape
    outputs_pose_6d = self.mean_pose.view(1, 1, -1)
    outputs_shape = self.mean_shape.view(1, 1, -1)
    for lvl in range(hs.shape[0]):

        outputs_pose_6d = outputs_pose_6d + self.pose_head[lvl](hs[lvl])
        outputs_shape = outputs_shape + self.shape_head[lvl](hs[lvl])

        # during training every layer is decoded (for aux losses);
        # at eval time only the last layer is decoded to save compute
        if self.training or lvl == hs.shape[0] - 1:
            outputs_pose = rot6d_to_axis_angle(outputs_pose_6d)

            outputs_conf = self.conf_head(hs[lvl]).sigmoid()

            # cam
            cam_xys = self.cam_head(hs[lvl])

            outputs_vert, outputs_j3d, outputs_j2d, depth, transl\
                = self.process_smpl(poses = outputs_pose,
                                    shapes = outputs_shape,
                                    cam_xys = cam_xys,
                                    cam_intrinsics = cam_intrinsics,
                                    detach_j3ds = detach_j3ds)

            outputs_poses.append(outputs_pose)
            outputs_shapes.append(outputs_shape)
            outputs_confs.append(outputs_conf)
            # outputs_verts.append(outputs_vert)
            outputs_j3ds.append(outputs_j3d)
            outputs_j2ds.append(outputs_j2d)
            outputs_depths.append(depth)

    pred_poses = torch.stack(outputs_poses)
    pred_betas = torch.stack(outputs_shapes)
    pred_confs = torch.stack(outputs_confs)
    # verts/transl/intrinsics are kept only for the last decoded layer
    pred_verts = outputs_vert
    pred_transl = transl
    pred_intrinsics = cam_intrinsics
    pred_j3ds = torch.stack(outputs_j3ds)
    pred_j2ds = torch.stack(outputs_j2ds)
    pred_depths = torch.stack(outputs_depths)



    # split denoising queries back out of the prediction tensors
    # NOTE(review): self.training is a bool, so the `> 0` is redundant
    if self.training > 0 and self.use_dn:
        pred_poses, pred_betas,\
        pred_boxes, pred_confs,\
        pred_j3ds, pred_j2ds, pred_depths,\
        pred_verts, pred_transl =\
            dn_post_process(pred_poses, pred_betas,
                            pred_boxes, pred_confs,
                            pred_j3ds, pred_j2ds, pred_depths,
                            pred_verts, pred_transl,
                            dn_meta, self.aux_loss, self._set_aux_loss)


    # main outputs come from the last decoder layer (index -1)
    out = {'pred_poses': pred_poses[-1], 'pred_betas': pred_betas[-1],
           'pred_boxes': pred_boxes[-1], 'pred_confs': pred_confs[-1],
           'pred_j3ds': pred_j3ds[-1], 'pred_j2ds': pred_j2ds[-1],
           'pred_verts': pred_verts, 'pred_intrinsics': pred_intrinsics,
           'pred_depths': pred_depths[-1], 'pred_transl': pred_transl}

    if self.aux_loss and self.training:
        out['aux_outputs'] = self._set_aux_loss(pred_poses, pred_betas,
                                                pred_boxes, pred_confs,
                                                pred_j3ds, pred_j2ds, pred_depths)

    if self.use_sat:
        out['enc_outputs'] = scale_map_dict

        out['sat'] = sat_dict

    if self.training > 0 and self.use_dn:
        out['dn_meta'] = dn_meta

    return out
+
@torch.jit.unused
|
| 700 |
+
def _set_aux_loss(self, pred_poses, pred_betas, pred_boxes,
|
| 701 |
+
pred_confs, pred_j3ds,
|
| 702 |
+
pred_j2ds, pred_depths):
|
| 703 |
+
# this is a workaround to make torchscript happy, as torchscript
|
| 704 |
+
# doesn't support dictionary with non-homogeneous values, such
|
| 705 |
+
# as a dict having both a Tensor and a list.
|
| 706 |
+
return [{'pred_poses': a, 'pred_betas': b,
|
| 707 |
+
'pred_boxes': c, 'pred_confs': d,
|
| 708 |
+
'pred_j3ds': e, 'pred_j2ds': f, 'pred_depths': g}
|
| 709 |
+
for a, b, c, d, e, f, g in zip(pred_poses[:-1], pred_betas[:-1],
|
| 710 |
+
pred_boxes[:-1], pred_confs[:-1], pred_j3ds[:-1], pred_j2ds[:-1], pred_depths[:-1])]
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN): a stack of
    Linear layers with ReLU between them and no activation after the last."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # dims = [in, hidden, ..., hidden, out]; consecutive pairs define layers
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:  # no activation on the output layer
                x = F.relu(x)
        return x
|
| 729 |
+
def build_sat_model(args, set_criterion=True):
    """Build the SAT model and, optionally, its training criterion.

    Args:
        args: config namespace with model hyper-parameters, loss weights
            (args.weight_dict), loss names (args.losses) and the sat/dn configs.
        set_criterion: when False, skip matcher/criterion construction
            (e.g. for inference) and return (model, None).

    Returns:
        (model, criterion) — criterion is None when set_criterion is False.
    """
    model = Model(
        build_encoder(args),
        build_decoder(args),
        num_queries=args.num_queries,
        input_size=args.input_size,
        sat_cfg=args.sat_cfg,
        dn_cfg=args.dn_cfg,
        train_pos_embed=getattr(args, 'train_pos_embed', True)
    )

    if not set_criterion:
        return model, None

    matcher = build_matcher(args)
    weight_dict = args.weight_dict  # NOTE: extended in place below
    losses = args.losses

    # denoising losses reuse the base weights under '<key>_dn' names
    if args.dn_cfg['use_dn']:
        weight_dict.update({f'{k}_dn': v for k, v in weight_dict.items()})

    # per-decoder-layer auxiliary losses ('<key>.<i>');
    # dn keys are intentionally already in weight_dict and get replicated too
    aux_weight_dict = {}
    for layer_idx in range(args.dec_layers - 1):
        aux_weight_dict.update({f'{k}.{layer_idx}': v for k, v in weight_dict.items()})
    weight_dict.update(aux_weight_dict)

    # the SAT scale-map confidence loss defaults to the query-confidence weight
    if args.sat_cfg['use_sat']:
        if 'map_confs' not in weight_dict:
            weight_dict.update({'map_confs': weight_dict['confs']})
        # weight_dict.update({'map_scales': })

    criterion = SetCriterion(matcher, weight_dict, losses = losses, j2ds_norm_scale = args.input_size)
    return model, criterion
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.26.1
|
| 2 |
+
chumpy
|
| 3 |
+
smplx
|
| 4 |
+
opencv-python
|
| 5 |
+
trimesh
|
| 6 |
+
tensorboard
|
| 7 |
+
scipy
|
| 8 |
+
pyrender==0.1.45
|
| 9 |
+
joblib
|
| 10 |
+
termcolor
|
| 11 |
+
transformers
|
| 12 |
+
matplotlib
|
| 13 |
+
scikit-learn
|
utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
utils/box_ops.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
Utilities for bounding box manipulation and GIoU.
|
| 4 |
+
"""
|
| 5 |
+
import torch, os
|
| 6 |
+
from torchvision.ops.boxes import box_area
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
|
| 16 |
+
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
    x0, y0, x1, y1 = x.unbind(-1)
    center_and_size = [(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0]
    return torch.stack(center_and_size, dim=-1)
|
| 23 |
+
# modified from torchvision to also return the union
|
| 24 |
+
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of xyxy boxes (modified from torchvision
    to also return the union).

    Returns:
        iou: [N, M] IoU matrix; the union gets a 1e-6 epsilon for stability.
        union: [N, M] union areas.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # clamp to zero for non-overlapping pairs
    inter_wh = (bottom_right - top_left).clamp(min=0)             # [N,M,2]
    inter = inter_wh[..., 0] * inter_wh[..., 1]                   # [N,M]

    union = area1[:, None] + area2 - inter
    return inter / (union + 1e-6), union
|
| 41 |
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format.

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2).
    """
    # degenerate boxes give inf / nan results, so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # smallest enclosing box of every pair
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_br - enclose_tl).clamp(min=0)  # [N,M,2]
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area
    return iou - (enclose_area - union) / (enclose_area + 1e-6)
| 67 |
+
|
| 68 |
+
# modified from torchvision to also return the union
|
| 69 |
+
def box_iou_pairwise(boxes1, boxes2):
    """Element-wise IoU between two equal-length sets of xyxy boxes
    (modified from torchvision to also return the union).

    Args:
        boxes1, boxes2: [N, 4] boxes, matched index by index.

    Returns:
        iou: [N] IoU per pair; union gets a 1e-6 epsilon, consistent with
            box_iou, so degenerate (zero-area) pairs yield 0 instead of NaN.
        union: [N] union areas.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # [N,2]
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # [N,2]

    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]

    union = area1 + area2 - inter

    # epsilon matches box_iou above and avoids 0/0 for degenerate boxes
    iou = inter / (union + 1e-6)
    return iou, union
|
| 85 |
+
def generalized_box_iou_pairwise(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/, element-wise.

    Input:
        - boxes1, boxes2: [N, 4] boxes in [x0, y0, x1, y1] format, matched
          index by index.
    Output:
        - giou: [N] GIoU per pair.
    """
    # degenerate boxes give inf / nan results, so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    assert boxes1.shape == boxes2.shape
    iou, union = box_iou_pairwise(boxes1, boxes2)  # [N]

    # smallest enclosing box of each pair
    lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,2]
    area = wh[:, 0] * wh[:, 1]

    # epsilon matches generalized_box_iou above and avoids 0/0 when both
    # boxes are degenerate (zero enclosing area)
    return iou - (area - union) / (area + 1e-6)
| 109 |
+
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks,
    (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # build coordinate grids on the masks' device so this also works for
    # CUDA tensors; indexing="ij" keeps the (row=y, col=x) layout and
    # silences the torch.meshgrid deprecation warning
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x, indexing="ij")

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    # fill background with a large value so it never wins the min
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
| 135 |
+
if __name__ == '__main__':
    # quick smoke test of the pairwise IoU helper.
    # Replaced the leftover `import ipdb; ipdb.set_trace()` debugger hook —
    # ipdb is not a declared dependency and a breakpoint must not ship.
    x = torch.rand(5, 4)
    y = torch.rand(3, 4)
    iou, union = box_iou(x, y)
    print('iou shape:', tuple(iou.shape), 'union shape:', tuple(union.shape))
|