ChiSu001 committed on
Commit
ff07ed4
·
verified ·
1 Parent(s): 3063eac

Upload model files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. .gitignore +9 -0
  3. configs/__init__.py +0 -0
  4. configs/models/default.yaml +29 -0
  5. configs/paths.py +7 -0
  6. configs/run/demo.yaml +12 -0
  7. datasets/__init__.py +0 -0
  8. datasets/agora.py +111 -0
  9. datasets/base.py +521 -0
  10. datasets/bedlam.py +72 -0
  11. datasets/common.py +34 -0
  12. datasets/multiple_datasets.py +49 -0
  13. demo/img0.png +3 -0
  14. demo/img1.jpeg +0 -0
  15. demo/img2.jpg +3 -0
  16. docs/fix_chumpy.md +44 -0
  17. engines/__init__.py +0 -0
  18. engines/engine.py +347 -0
  19. engines/funcs/__init__.py +0 -0
  20. engines/funcs/eval_funcs.py +362 -0
  21. engines/funcs/infer_funcs.py +86 -0
  22. figures/pipeline.png +3 -0
  23. figures/qualitative_results.png +3 -0
  24. figures/results.png +3 -0
  25. figures/results_3d.gif +3 -0
  26. main.py +52 -0
  27. models/__init__.py +16 -0
  28. models/criterion.py +449 -0
  29. models/decoder.py +388 -0
  30. models/dn_components.py +193 -0
  31. models/encoders/__init__.py +52 -0
  32. models/encoders/dinov2/layers/__init__.py +11 -0
  33. models/encoders/dinov2/layers/attention.py +89 -0
  34. models/encoders/dinov2/layers/block.py +260 -0
  35. models/encoders/dinov2/layers/dino_head.py +58 -0
  36. models/encoders/dinov2/layers/drop_path.py +34 -0
  37. models/encoders/dinov2/layers/layer_scale.py +27 -0
  38. models/encoders/dinov2/layers/mlp.py +40 -0
  39. models/encoders/dinov2/layers/patch_embed.py +88 -0
  40. models/encoders/dinov2/layers/swiglu_ffn.py +72 -0
  41. models/encoders/dinov2/models/__init__.py +43 -0
  42. models/encoders/dinov2/models/vision_transformer.py +542 -0
  43. models/human_models/__init__.py +1 -0
  44. models/human_models/smpl_models.py +69 -0
  45. models/matcher.py +159 -0
  46. models/position_encoding.py +155 -0
  47. models/sat_model.py +767 -0
  48. requirements.txt +13 -0
  49. utils/__init__.py +1 -0
  50. utils/box_ops.py +139 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo/img0.png filter=lfs diff=lfs merge=lfs -text
37
+ demo/img2.jpg filter=lfs diff=lfs merge=lfs -text
38
+ figures/pipeline.png filter=lfs diff=lfs merge=lfs -text
39
+ figures/qualitative_results.png filter=lfs diff=lfs merge=lfs -text
40
+ figures/results_3d.gif filter=lfs diff=lfs merge=lfs -text
41
+ figures/results.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ debug_datas.py
2
+ outputs/
3
+ weights/
4
+ results/
5
+ tmps/
6
+ **.out
7
+ **/__pycache__/
8
+ datasets_visualization/
9
+ demo_results/
configs/__init__.py ADDED
File without changes
configs/models/default.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_size: 1288
2
+ encoder: 'vitb'
3
+
4
+ # decoder
5
+ hidden_dim: 768
6
+ nheads: 4
7
+ dec_layers: 6
8
+ dim_feedforward: 2048
9
+ dropout: 0.0
10
+ num_queries: 50
11
+ transformer_activation: "relu"
12
+
13
+ sat_cfg:
14
+ use_sat: True
15
+ share_patch_embed: False
16
+ preprocess_pos_embed: False
17
+ num_lvls: 3
18
+ lvl_embed: True
19
+ get_map_layer: 3
20
+ use_additional_blocks: True
21
+ conf_thresh: 0.3
22
+ scale_thresh: 0.5
23
+
24
+ dn_cfg:
25
+ use_dn: True
26
+ dn_number: 10
27
+ tgt_embed_type: "params"
28
+ box_noise_scale: 0.4
29
+ tgt_noise_scale: 0.2
configs/paths.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset_root = '../datasets'
2
+
3
+ smpl_model_path = './weights/smpl_data'
4
+ smpl_mean_path = './weights/smpl_data/smpl/smpl_mean_params.npz'
5
+
6
+ dinov2_vitb14_path = './weights/dinov2/dinov2_vitb14_pretrain.pth'
7
+ dinov2_vitl14_path = './weights/dinov2/dinov2_vitl14_pretrain.pth'
configs/run/demo.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: default
2
+
3
+ pretrain: True
4
+ pretrain_path: './weights/sat_hmr/sat_644.pth'
5
+
6
+ input_dir: './demo'
7
+ output_dir: './demo_results'
8
+ conf_thresh: [0.3]
9
+ infer_batch_size: 1
10
+ infer_num_workers: 8
11
+ distributed_infer: True
12
+
datasets/__init__.py ADDED
File without changes
datasets/agora.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data.dataset import Dataset
4
+ import os
5
+ from configs.paths import dataset_root
6
+ import copy
7
+ from tqdm import tqdm
8
+ from .base import BASE
9
+
10
+
11
class AGORA(BASE):
    """AGORA dataset wrapper.

    The 'train' and 'validation' splits load pre-packed SMPL annotation
    archives; the 'test' split ships no annotations, so the dataset switches
    itself to inference mode and only lists image files.
    """

    def __init__(self, split='train', **kwargs):
        super(AGORA, self).__init__(**kwargs)
        assert split in ['train', 'test', 'validation']

        self.ds_name = 'agora'
        self.split = split
        self.dataset_path = os.path.join(dataset_root, 'agora')

        # no annotations are available for AGORA-test
        if split == 'test':
            self.mode = 'infer'
            self.img_names = os.listdir(os.path.join(self.dataset_path, self.split))
        else:
            # train uses the fitted-SMPL archive; validation uses the plain one
            if self.split == 'train':
                annots_path = os.path.join(self.dataset_path, 'smpl_neutral_annots', 'annots_smpl_{}_fit.npz'.format(split))
            else:
                annots_path = os.path.join(self.dataset_path, 'smpl_neutral_annots', 'annots_smpl_{}.npz'.format(split))
            self.annots = np.load(annots_path, allow_pickle=True)['annots'][()]
            self.img_names = list(self.annots.keys())

    def __len__(self):
        return len(self.img_names)

    def get_raw_data(self, idx):
        """Return the raw (pre-augmentation) sample dict for image ``idx``.

        In 'infer' mode the dict carries only the image path/name; otherwise it
        bundles per-person SMPL parameters, camera extrinsics/intrinsics and
        validity flags consumed by BASE.process_data.
        """
        img_id = idx % len(self.img_names)
        img_name = self.img_names[img_id]

        if self.mode == 'infer':
            img_path = os.path.join(self.dataset_path, self.split, img_name)
            raw_data = {'img_path': img_path,
                        'img_name': img_name,
                        'ds': 'agora'
                        }
            return raw_data

        annots = copy.deepcopy(self.annots[img_name])
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        valid_mask = np.where(annots['isValid'])[0]

        # This should not happen: every annotated image is expected to contain
        # at least one valid person. Raise instead of exit(0) — exiting with
        # status 0 from library code would silently report success.
        if len(valid_mask) == 0:
            raise RuntimeError('{} lacks a valid person'.format(img_name))

        cam_intrinsics = torch.from_numpy(np.array(annots['cam_intrinsics']))
        cam_rot = torch.from_numpy(np.array(annots['cam_rot'])[valid_mask])
        cam_trans = torch.from_numpy(np.array(annots['cam_trans'])[valid_mask])

        betas_list = []
        poses_list = []
        transl_list = []

        kid = []

        if self.mode == 'eval':
            occ_level_list = []  # renamed from 'occ_leval_list' (typo)

        for pNum in range(len(annots['isValid'])):
            if not annots['isValid'][pNum]:
                continue

            gt = annots['smpl_gt'][pNum]
            betas = gt['betas'].flatten()[:10]
            betas_list.append(torch.from_numpy(betas))
            # full pose = global orientation + body pose, flattened axis-angles
            full_poses = torch.cat([torch.from_numpy(gt['global_orient'].flatten()),
                                    torch.from_numpy(gt['body_pose'].flatten())])
            poses_list.append(full_poses)
            transl_list.append(torch.from_numpy(gt['transl'].flatten()))

            kid.append(annots['kid'][pNum])

            if self.mode == 'eval':
                # bucket the occlusion percentage into 10%-wide levels
                occ_level_list.append(int(annots['occlusion'][pNum] // 10))

        betas = torch.stack(betas_list)
        poses = torch.stack(poses_list)
        transl = torch.stack(transl_list)

        raw_data = {'img_path': img_path,
                    'ds': 'agora',
                    'pnum': len(betas),
                    'betas': betas.float(),
                    'poses': poses.float(),
                    'transl': transl.float(),
                    'kid': torch.tensor(kid),
                    'cam_rot': cam_rot.float(),
                    'cam_trans': cam_trans.float(),
                    'cam_intrinsics': cam_intrinsics.float(),
                    '3d_valid': True,
                    'age_valid': True,
                    'detect_all_people': True
                    }

        if self.mode == 'eval':
            raw_data['occ_level'] = torch.tensor(occ_level_list)

        return raw_data
datasets/base.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import os
3
+ import numpy as np
4
+ from torch.utils.data.dataset import Dataset
5
+ from torchvision import transforms
6
+ from utils.visualization import tensor_to_BGR, vis_meshes_img, vis_boxes, vis_scale_img, pad_img, get_colors_rgb, vis_sat
7
+ from utils.transforms import unNormalize, to_zorder
8
+ from PIL import Image
9
+ import math
10
+ from tqdm import tqdm
11
+ import cv2
12
+ import torch
13
+ import copy
14
+ from math import radians,sin,cos
15
+ from utils import constants
16
+ from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
17
+ from utils.constants import smpl_24_flip, smpl_root_idx
18
+ from utils.map import gen_scale_map, build_z_map
19
+ from configs.paths import smpl_model_path
20
+ from models.human_models import SMPL_Layer, smpl_gendered
21
+
22
+
23
class BASE(Dataset):
    """Common base class for the SMPL human-mesh datasets in this package.

    Subclasses implement ``get_raw_data`` (returning a per-image dict of SMPL
    parameters, cameras and flags); this class handles image loading, resize /
    rotation / flip / crop / occlusion augmentation, camera and SMPL parameter
    updates under those augmentations, 2D projection, visibility filtering,
    box generation and (optionally) SAT scale-map generation.
    """
    def __init__(self, input_size = 1288, aug = True, mode = 'train',
                 human_type = 'smpl',
                 sat_cfg = None,
                 aug_cfg = None):
        self.input_size = input_size
        self.aug = aug
        if mode not in ['train', 'eval', 'infer']:
            raise NotImplementedError
        if human_type not in ['smpl', 'no']:
            raise NotImplementedError
        self.mode = mode
        self.human_type = human_type
        assert sat_cfg is not None
        self.use_sat = sat_cfg['use_sat']
        self.sat_cfg = sat_cfg

        # SAT pads images to multiples of a 56-pixel patch (see __getitem__),
        # so the input size must divide evenly.
        if self.use_sat:
            assert input_size % 56 == 0

        # default training augmentation ranges when caller supplies none
        if self.mode == 'train' and aug_cfg is None:
            aug_cfg = {'rot_range': [-15, 15],
                       'scale_range': [0.8, 1.8],
                       'flip_ratio': 0.5,
                       'crop_ratio': 0.}
        self.aug_cfg = aug_cfg

        if human_type == 'smpl':
            self.poses_flip = smpl_24_flip
            self.num_poses = 24
            self.num_betas = 10
            self.num_kpts = 45
            self.human_model = smpl_gendered

        self.vis_thresh = 4 # least num visible kpts for a valid individual

        # keys kept in meta_data at train time (see __getitem__ cleanup);
        # img_keys are per-image, human_keys are per-person
        self.img_keys = ['img_path', 'ds',
                         'pnum', 'img_size',
                         'resize_rate', 'cam_intrinsics',
                         '3d_valid', 'detect_all_people',
                         'scale_map', 'scale_map_pos', 'scale_map_hw']
        self.human_keys = ['boxes', 'labels',
                           'poses', 'betas',
                           'transl', 'verts',
                           'j3ds', 'j2ds', 'j2ds_mask',
                           'depths', 'focals', 'genders']

        # depth of the Z-order (Morton) curve over 28-pixel patches
        z_depth = math.ceil(math.log2(self.input_size//28))
        self.z_order_map, self.y_coords, self.x_coords = build_z_map(z_depth)



    def get_raw_data(self, idx):
        """Subclass hook: return the raw annotation dict for sample ``idx``."""
        raise NotImplementedError

    def get_aug_dict(self):
        """Sample one set of augmentation parameters (identity when aug is off)."""
        if self.aug:
            rot = random.uniform(*self.aug_cfg['rot_range'])
            flip = random.random() <= self.aug_cfg['flip_ratio']
            scale = random.uniform(*self.aug_cfg['scale_range'])
            crop = random.random() <= self.aug_cfg['crop_ratio']
        else:
            rot = 0.
            flip = False
            scale = 1.
            crop = False

        return {'rot':rot, 'flip':flip, 'scale':scale, 'crop': crop}

    def process_img(self, img, meta_data, rot = 0., flip = False, scale = 1.0, crop = False):
        """Apply crop/resize/flip/rot-scale to ``img``; record sizes and the
        valid-pixel mask in ``meta_data``. Returns the transformed image."""
        # randomly crop (similar to scale)
        if self.mode == 'train' and crop:

            h, w = img.shape[:2]
            # NOTE(review): only landscape-or-square images are cropped here;
            # portrait images (h >= w) pass through unchanged — confirm intended.
            if h < w :
                clip_ratio = random.uniform(0.5, 0.9)
                tgt_h, tgt_w = int(h*clip_ratio), int(w*clip_ratio)

                # keep top rows, center-crop horizontally; shift principal point
                img = img[:tgt_h,(w-tgt_w)//2:(w+tgt_w)//2,:].copy()
                cam_intrinsics = meta_data['cam_intrinsics']
                cam_intrinsics[:,0,2] -= (w-tgt_w)//2
                meta_data.update({'cam_intrinsics': cam_intrinsics})


        # resize so the longer side equals self.input_size
        img_size = torch.tensor(img.shape[:2])
        if img_size[1] >= img_size[0]:
            resize_rate = self.input_size/img_size[1]
            img = cv2.resize(img,dsize=(self.input_size,int(resize_rate*img_size[0])))
            img_size = torch.tensor([int(resize_rate*img_size[0]),self.input_size])
        else:
            resize_rate = self.input_size/img_size[0]
            img = cv2.resize(img,dsize=(int(resize_rate*img_size[1]),self.input_size))
            img_size = torch.tensor([self.input_size,int(resize_rate*img_size[1])])
        meta_data.update({'img_size': img_size, 'resize_rate': resize_rate})

        # flip (horizontal); rotation direction flips with it
        if flip:
            img = np.flip(img, axis = 1)
            rot = -rot

        # rot and scale around the image center; img_valid marks pixels that
        # remain inside the original frame after warping
        img_valid = np.full((img.shape[0], img.shape[1]), 255, dtype = np.uint8)
        M = cv2.getRotationMatrix2D((int(img_size[1]/2),int(img_size[0]/2)), rot, scale)
        img = cv2.warpAffine(img, M, dsize = (img.shape[1],img.shape[0]))
        img_valid = cv2.warpAffine(img_valid, M, dsize = (img.shape[1],img.shape[0]))
        meta_data.update({'img_valid': img_valid})

        return img

    def occlusion_aug(self, meta_data):
        """Sample one synthetic occlusion rectangle per person box (60% chance;
        zero-sized otherwise). Returns a list of (ymin, h, xmin, w) tuples."""
        occ_boxes = []
        imght, imgwidth = meta_data['img_size']
        for bbox in box_cxcywh_to_xyxy(meta_data['boxes']):
            bbox = bbox.clone()
            bbox *= self.input_size
            xmin, ymin = bbox[:2]
            xmax, ymax = bbox[2:]

            if random.random() <= 0.6:
                counter = 0
                while True:
                    # force to break if no suitable occlusion
                    if counter > 5:
                        synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
                        break
                    counter += 1

                    # occluder covers up to 30% of the person box area
                    area_min = 0.0
                    area_max = 0.3
                    synth_area = (random.random() * (area_max - area_min) + area_min) * (xmax - xmin) * (ymax - ymin)

                    # aspect ratio between 0.5 and 2
                    ratio_min = 0.5
                    ratio_max = 1 / 0.5
                    synth_ratio = (random.random() * (ratio_max - ratio_min) + ratio_min)

                    synth_h = math.sqrt(synth_area * synth_ratio)
                    synth_w = math.sqrt(synth_area / synth_ratio)
                    synth_xmin = random.random() * ((xmax - xmin) - synth_w - 1) + xmin
                    synth_ymin = random.random() * ((ymax - ymin) - synth_h - 1) + ymin

                    # accept only if fully inside the image
                    if synth_xmin >= 0 and synth_ymin >= 0 and synth_xmin + synth_w < imgwidth and synth_ymin + synth_h < imght:
                        synth_xmin = int(synth_xmin)
                        synth_ymin = int(synth_ymin)
                        synth_w = int(synth_w)
                        synth_h = int(synth_h)
                        break
            else:
                synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
            occ_boxes.append((synth_ymin, synth_h, synth_xmin, synth_w))
        return occ_boxes

    def get_boxes(self, meta_data):
        """Derive normalized cxcywh person boxes from projected 2D keypoints
        (enlarged 1.2x, clamped to the image) and store them in meta_data."""
        j2ds = meta_data['j2ds']
        j2ds_mask = meta_data['j2ds_mask']
        pnum = meta_data['pnum']

        bboxes_list = []

        for i in range(pnum):
            kpts = j2ds[i].clone()
            min_xy = kpts.min(dim = 0)[0]
            max_xy = kpts.max(dim = 0)[0]
            bbox_xyxy = torch.cat([min_xy, max_xy], dim = 0)
            bboxes_list.append(bbox_xyxy)

        imght, imgwidth = meta_data['img_size']
        boxes = box_xyxy_to_cxcywh(torch.stack(bboxes_list)) / self.input_size
        boxes[...,2:] *= 1.2
        boxes = box_cxcywh_to_xyxy(boxes)
        boxes[...,[0,2]] = boxes[...,[0,2]].clamp(min=0.01,max=(imgwidth-1)/self.input_size)
        boxes[...,[1,3]] = boxes[...,[1,3]].clamp(min=0.01,max=(imght-1)/self.input_size)
        boxes = box_xyxy_to_cxcywh(boxes)

        meta_data.update({'boxes': boxes})

    def process_cam(self, meta_data, rot = 0., flip = False, scale = 1.):
        """Update camera intrinsics/extrinsics in meta_data to match the
        resize/scale/rotation/flip applied to the image. Mutates meta_data."""
        img_size = meta_data['img_size']
        resize_rate = meta_data['resize_rate']
        rot_aug_mat = meta_data['rot_aug_mat']
        cam_intrinsics = meta_data['cam_intrinsics']
        # cam_int
        # resize: focal lengths and principal point scale together
        cam_intrinsics[:,0:2,2] *= resize_rate * scale
        cam_intrinsics[:,[0,1],[0,1]] *= resize_rate * scale
        cam_intrinsics[:,0,2] += (1-scale)*img_size[1]/2
        cam_intrinsics[:,1,2] += (1-scale)*img_size[0]/2
        # rotation: rotate the principal point about the image center
        princpt = cam_intrinsics[:,0:2,2].clone()
        princpt[...,0] -= img_size[1]/2
        princpt[...,1] -= img_size[0]/2
        princpt = torch.matmul(princpt,rot_aug_mat[:2,:2].transpose(-1,-2))
        princpt[...,0] += img_size[1]/2
        princpt[...,1] += img_size[0]/2
        cam_intrinsics[:,0:2,2] = princpt
        # flip: mirror the principal point horizontally
        if flip:
            cam_intrinsics[:,0,2] = img_size[1]-cam_intrinsics[:,0,2]
        meta_data.update({'cam_intrinsics': cam_intrinsics})

        #cam_ext: fold the augmentation rotation into the extrinsics
        new_cam_rot = torch.matmul(rot_aug_mat.unsqueeze(0),meta_data['cam_rot'])
        new_cam_trans = torch.matmul(meta_data['cam_trans'],rot_aug_mat.transpose(-1,-2))
        meta_data.update({'cam_rot': new_cam_rot,'cam_trans':new_cam_trans})

    def process_smpl(self, meta_data, rot = 0., flip = False, scale = 1.):
        """Rotate/flip SMPL pose parameters consistently with the camera, run
        the SMPL layer, and store camera-space verts/j3ds/transl in meta_data."""
        poses = meta_data['poses']
        bs = poses.shape[0]
        assert poses.ndim == 2
        assert tuple(poses.shape) == (bs, self.num_poses*3)
        # Merge rotation to smpl global_orient
        global_orient = poses[:,:3].clone()
        cam_rot = meta_data['cam_rot'].numpy()
        for i in range(global_orient.shape[0]):
            # axis-angle -> matrix, pre-multiply camera rotation, back to axis-angle
            root_pose = global_orient[i].view(1, 3).numpy()
            R = cam_rot[i].reshape(3,3)
            root_pose, _ = cv2.Rodrigues(root_pose)
            root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
            root_pose = torch.from_numpy(root_pose).flatten()
            global_orient[i] = root_pose
        poses[:,:3] = global_orient

        # Flip smpl parameters
        if flip:
            poses = poses.reshape(bs, self.num_poses, 3)
            poses = poses[:, self.poses_flip, :]
            poses[..., 1:3] *= -1 # multiply -1 to y and z axis of axis-angle
            poses = poses.reshape(bs, -1)

        # Update all pose params
        meta_data.update({'poses': poses})

        # Get vertices and joints in cam_coords
        with torch.no_grad():
            smpl_kwargs = {'poses': meta_data['poses'], 'betas': meta_data['betas']}
            if 'genders' in meta_data:
                smpl_kwargs.update({'genders': meta_data['genders']})
            verts, j3ds = self.human_model(**smpl_kwargs)

        j3ds = j3ds[:, :self.num_kpts, :]
        root = j3ds[:,smpl_root_idx,:].clone() # smpl root
        # new translation in cam_coords
        transl = torch.bmm((root+meta_data['transl']).reshape(-1,1,3),meta_data['cam_rot'].transpose(-1,-2)).reshape(-1,3)\
            +meta_data['cam_trans']-root
        if flip:
            transl[...,0] = -transl[...,0]

        meta_data.update({'transl': transl})

        verts = verts + transl.reshape(-1,1,3)
        j3ds = j3ds + transl.reshape(-1,1,3)
        meta_data.update({'verts': verts, 'j3ds': j3ds})

    def project_joints(self, meta_data):
        """Project camera-space j3ds through the intrinsics to pixel j2ds."""
        j3ds = meta_data['j3ds']
        cam_intrinsics = meta_data['cam_intrinsics']
        j2ds_homo = torch.matmul(j3ds,cam_intrinsics.transpose(-1,-2))
        j2ds = j2ds_homo[...,:2]/(j2ds_homo[...,2,None])

        meta_data.update({'j3ds': j3ds, 'j2ds': j2ds})

    def check_visibility(self, meta_data):
        """Mask out joints that fall outside the valid (warped) image area and
        drop people with fewer than ``vis_thresh`` visible joints; updates
        'j2ds_mask', 'pnum' and filters all human_keys arrays in meta_data."""
        img_valid = meta_data['img_valid']
        img_size = meta_data['img_size']

        j2ds = meta_data['j2ds']
        j2ds_mask = meta_data['j2ds_mask'] if 'j2ds_mask' in meta_data else torch.ones_like(j2ds, dtype=bool)


        # visible = lands on a valid warped pixel AND inside image bounds
        j2ds_vis = torch.from_numpy(img_valid[j2ds[...,1].int().clip(0,img_size[0]-1), j2ds[...,0].int().clip(0,img_size[1]-1)] > 0)
        j2ds_vis &= (j2ds[...,1] >= 0) & (j2ds[...,1] < img_size[0])
        j2ds_vis &= (j2ds[...,0] >= 0) & (j2ds[...,0] < img_size[1])

        j2ds_invalid = ~j2ds_vis
        j2ds_mask[j2ds_invalid] = False
        meta_data.update({'j2ds_mask': j2ds_mask})

        vis_cnt = j2ds_mask[...,0].sum(dim = -1) # num of visible joints per person
        valid_msk = (vis_cnt >= self.vis_thresh)

        pnum = valid_msk.sum().item()

        if pnum == 0:
            meta_data['pnum'] = pnum
            return

        if pnum < meta_data['pnum']:
            meta_data['pnum'] = pnum
            for key in self.human_keys:
                if key in meta_data:
                    if isinstance(meta_data[key], list):
                        meta_data[key] = np.array(meta_data[key])[valid_msk].tolist()
                    else:
                        meta_data[key] = meta_data[key][valid_msk]
            # per-person intrinsics (len > 1) must be filtered too
            if 'cam_intrinsics' in meta_data and len(meta_data['cam_intrinsics']) > 1:
                meta_data['cam_intrinsics'] = meta_data['cam_intrinsics'][valid_msk]

        return


    def process_data(self, img, raw_data, rot = 0., flip = False, scale = 1., crop = False):
        """Run the full augmentation pipeline on one sample.

        Returns (img, meta_data); meta_data['pnum'] == 0 signals the caller
        (in train mode) to resample the augmentation.
        """
        meta_data = copy.deepcopy(raw_data)
        # prepare rotation augmentation mat.
        rot_aug_mat = torch.tensor([[cos(radians(-rot)), -sin(radians(-rot)), 0.],
                                    [sin(radians(-rot)), cos(radians(-rot)), 0.],
                                    [0., 0., 1.]])
        meta_data.update({'rot_aug_mat': rot_aug_mat})

        img = self.process_img(img, meta_data, rot, flip, scale, crop)

        self.process_cam(meta_data, rot, flip, scale)
        self.process_smpl(meta_data, rot, flip, scale)
        self.project_joints(meta_data)
        self.check_visibility(meta_data)
        matcher_vis = meta_data['j2ds_mask'][:,:22,0].sum(dim = -1) # num of visible joints used in Hungarian Matcher
        # reject the sample if anyone has no matcher-visible joints at all
        if meta_data['pnum'] == 0 or not torch.all(matcher_vis):
            if self.mode == 'train':
                meta_data['pnum'] = 0
            return img, meta_data

        # per-person root depth, plus depth normalized by focal length
        j3ds = meta_data['j3ds']
        depths = j3ds[:, smpl_root_idx, [2]].clone()
        if len(meta_data['cam_intrinsics']) == 1:
            focals = torch.full_like(depths, meta_data['cam_intrinsics'][0,0,0])
        else:
            focals = meta_data['cam_intrinsics'][:,0,0][:, None]
        depths = torch.cat([depths, depths/focals],dim=-1)
        meta_data.update({'depths': depths, 'focals': focals})

        self.get_boxes(meta_data)

        # single class: every target is 'person' (label 0)
        meta_data.update({'labels': torch.zeros(meta_data['pnum'], dtype=int)})


        # VI. Occlusion augmentation
        if self.aug:
            occ_boxes = self.occlusion_aug(meta_data)
            for (synth_ymin, synth_h, synth_xmin, synth_w) in occ_boxes:
                img[synth_ymin:synth_ymin + synth_h, synth_xmin:synth_xmin + synth_w, :] = np.random.rand(synth_h, synth_w, 3) * 255

        if self.use_sat:
            # scale map: render per-person scale (box diagonal) back-to-front
            # into a patch grid, then flatten in Z-order
            boxes = meta_data['boxes']
            scales = boxes[:,2:].norm(p=2,dim=1)
            v3ds = meta_data['verts']
            depths_norm = meta_data['depths'][:,1]
            cam_intrinsics = meta_data['cam_intrinsics']
            sorted_idx = torch.argsort(depths_norm, descending=True)
            map_size = (meta_data['img_size'] + 27)//28

            scale_map = gen_scale_map(scales[sorted_idx], v3ds[sorted_idx],
                                      faces = self.human_model.faces,
                                      cam_intrinsics = cam_intrinsics[sorted_idx] if len(cam_intrinsics) > 1 else cam_intrinsics,
                                      map_size = map_size,
                                      patch_size = 28,
                                      pad = True)
            scale_map_z, _, pos_y, pos_x = to_zorder(scale_map,
                                                     z_order_map = self.z_order_map,
                                                     y_coords = self.y_coords,
                                                     x_coords = self.x_coords)
            meta_data['scale_map'] = scale_map_z
            meta_data['scale_map_pos'] = {'pos_y': pos_y, 'pos_x': pos_x}
            meta_data['scale_map_hw'] = scale_map.shape[:2]



        return img, meta_data

    def __getitem__(self, index):
        """Load, augment and normalize one sample; returns (norm_img, meta_data)."""

        raw_data = self.get_raw_data(index)

        # Load original image
        ori_img = cv2.imread(raw_data['img_path'])
        # BEDLAM 'closeup' frames are stored rotated 90 degrees
        if raw_data['ds'] == 'bedlam' and 'closeup' in raw_data['img_path']:
            ori_img = cv2.rotate(ori_img, cv2.ROTATE_90_CLOCKWISE)
        img_size = torch.tensor(ori_img.shape[:2])
        raw_data.update({'img_size': img_size})

        if self.mode == 'train':
            # retry augmentations until at least one valid person survives;
            # after 10 tries fall back to identity rot/scale, then skip forward
            cnt = 0
            while (True):
                aug_dict = self.get_aug_dict()
                img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                if meta_data['pnum'] > 0:
                    break
                cnt+=1
                if cnt >= 10:
                    aug_dict.update({'rot':0., 'scale':1., 'crop': False})
                    img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                    if meta_data['pnum'] == 0:
                        print('skipping: ' + meta_data['img_path'])
                        return self.__getitem__(index + 1)

        elif self.mode == 'eval':
            assert not self.aug, f'No need to use augmentation when mode is {self.mode}!'
            aug_dict = self.get_aug_dict()
            img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)

        else:
            # infer: no annotations to process, just resize the image
            assert not self.aug, f'No need to use augmentation when mode is {self.mode}!'
            meta_data = raw_data
            img = self.process_img(ori_img, meta_data)

        # delete unwanted keys
        if self.mode == 'train':
            for key in list(meta_data.keys()):
                if key not in self.img_keys and key not in self.human_keys:
                    del meta_data[key]


        if self.aug:
            array2tensor = transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
        else:
            array2tensor = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])


        # pad to a multiple of the patch size (56 when SAT is on, else 14)
        # NOTE(review): local 'pad_img' shadows the imported pad_img helper
        patch_size = 14
        if self.use_sat:
            patch_size = 56
        pad_img = np.zeros((math.ceil(img.shape[0]/patch_size)*patch_size, math.ceil(img.shape[1]/patch_size)*patch_size, 3), dtype=img.dtype)
        pad_img[:img.shape[0], :img.shape[1]] = img
        assert max(pad_img.shape[:2]) == self.input_size
        pad_img = Image.fromarray(pad_img[:,:,::-1].copy())  # BGR -> RGB
        norm_img = array2tensor(pad_img)

        # NOTE(review): resets the visibility mask to all-True before returning
        # — presumably downstream losses handle occluded joints themselves
        if 'j2ds_mask' in meta_data:
            meta_data['j2ds_mask'][:,:,:] = True

        return norm_img, meta_data

    def visualize(self, results_save_dir = None, vis_num = 100):
        """Dump ~vis_num visualizations (meshes, boxes, scale maps) to disk
        for spot-checking the dataset pipeline."""
        if results_save_dir is None:
            results_save_dir = os.path.join('datasets_visualization',f'{self.ds_name}_{self.split}')
        os.makedirs(results_save_dir, exist_ok=True)

        vis_interval = len(self)//vis_num

        for idx in tqdm(range(len(self))):
            if idx % vis_interval != 0:
                continue

            norm_img, targets = self.__getitem__(idx)

            ori_img = tensor_to_BGR(unNormalize(norm_img).cpu())
            img_name = targets['img_path'].split('/')[-1].split('.')[-2]
            pnum = targets['pnum']

            if 'verts' in targets:
                colors = get_colors_rgb(len(targets['verts']))
                mesh_img = vis_meshes_img(img = ori_img.copy(),
                                          verts = targets['verts'],
                                          smpl_faces = self.human_model.faces,
                                          cam_intrinsics = targets['cam_intrinsics'].cpu(),
                                          colors=colors,
                                          padding=False)
                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_mesh.jpg'), mesh_img)


            if 'boxes' in targets:
                gt_img = ori_img.copy()
                boxes = box_cxcywh_to_xyxy(targets['boxes']) * self.input_size
                for i, bbox in enumerate(boxes):
                    bbox = bbox.int().tolist()
                    cv2.rectangle(gt_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                                  color=(0,0,255), thickness = 2 )

                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_boxes.jpg'), gt_img)

            if 'scale_map' in targets:
                # un-flatten the Z-ordered scale map back to (h, w, 2)
                gt_img = ori_img.copy()
                flatten_map = targets['scale_map']
                ys, xs = targets['scale_map_pos']['pos_y'], targets['scale_map_pos']['pos_x']
                h, w = targets['scale_map_hw']
                scale_map = torch.zeros((h,w,2))
                scale_map[ys,xs] = flatten_map
                img = vis_scale_img(gt_img, scale_map, patch_size=28)

                cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_scales.jpg'), img)


            # if 'j2ds' in targets:
            #     gt_img = ori_img.copy()
            #     j2ds = targets['j2ds']
            #     j2ds_mask = targets['j2ds_mask']
            #     for kpts, valids in zip(j2ds, j2ds_mask):
            #         for kpt, valid in zip(kpts, valids):
            #             if not valid.all():
            #                 continue
            #             kpt_int = kpt.numpy().astype(int)
            #             cv2.circle(gt_img, kpt_int, 2, (0, 0, 255), -1)
            #     cv2.imwrite(os.path.join(results_save_dir,f'{idx}_{img_name}_joints.png'), np.hstack([ori_img, gt_img]))
datasets/bedlam.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data.dataset import Dataset
4
+ import os
5
+ from configs.paths import dataset_root
6
+ import copy
7
+ from tqdm import tqdm
8
+ from .base import BASE
9
+
10
class BEDLAM(BASE):
    """BEDLAM synthetic dataset wrapper.

    Loads one pre-packed SMPL annotation archive per split; samples are always
    treated as fully annotated ('detect_all_people': True) with world-space
    pose/translation plus per-frame camera extrinsics.
    """

    def __init__(self, split='train_6fps', **kwargs):
        super(BEDLAM, self).__init__(**kwargs)
        assert split in ['train_1fps', 'train_3fps', 'train_6fps', 'validation_6fps']
        # BEDLAM has no kid shape annotations. BASE.__init__ never defines
        # `kid_offset`, so the original `assert not self.kid_offset` raised
        # AttributeError; use getattr so the intent-check is safe either way.
        assert not getattr(self, 'kid_offset', False)

        self.ds_name = 'bedlam'
        self.dataset_path = os.path.join(dataset_root, 'bedlam')
        annots_path = os.path.join(self.dataset_path, f'bedlam_smpl_{split}.npz')
        self.annots = np.load(annots_path, allow_pickle=True)['annots'][()]
        self.img_names = list(self.annots.keys())
        # images live under 'train/' or 'validation/' regardless of fps variant
        self.split = 'train' if 'train' in split else 'validation'

    def __len__(self):
        return len(self.img_names)

    def cnt_instances(self):
        """Print the total number of person instances across the split."""
        ins_cnt = 0
        for idx in tqdm(range(len(self))):
            img_id = idx
            img_name = self.img_names[img_id]
            # ins_cnt += len(self.annots[img_name]['isValid'])
            ins_cnt += len(self.annots[img_name]['shape'])
            # tqdm.write(str(ins_cnt))

        print(f'TOTAL: {ins_cnt}')

    def get_raw_data(self, idx):
        """Return the raw (pre-augmentation) sample dict for image ``idx``."""

        img_id = idx % len(self.img_names)
        img_name = self.img_names[img_id]

        annots = copy.deepcopy(self.annots[img_name])
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        # one shared intrinsics matrix per image; extrinsics are per person
        cam_intrinsics = torch.from_numpy(annots['cam_int']).unsqueeze(0)
        cam_rot = torch.from_numpy(np.stack(annots['cam_rot']))
        cam_trans = torch.from_numpy(np.stack(annots['cam_trans']))

        betas = torch.from_numpy(np.stack(annots['shape']))
        poses = torch.from_numpy(np.stack(annots['pose_world']))
        transl = torch.from_numpy(np.stack(annots['trans_world']))

        raw_data = {'img_path': img_path,
                    'ds': 'bedlam',
                    'pnum': len(betas),
                    'betas': betas.float(),
                    'poses': poses.float(),
                    'transl': transl.float(),
                    'cam_rot': cam_rot.float(),
                    'cam_trans': cam_trans.float(),
                    'cam_intrinsics': cam_intrinsics.float(),
                    '3d_valid': True,
                    'age_valid': False,
                    'detect_all_people': True
                    }

        if self.mode == 'eval':
            # synthetic data: no occlusion labels, report level 0 for everyone
            raw_data['occ_level'] = torch.zeros(len(betas), dtype=int)

        return raw_data
datasets/common.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data.dataset import Dataset
4
+ import os
5
+ from configs.paths import dataset_root
6
+ import copy
7
+ from .base import BASE
8
+
9
+ # dataset for inference
10
class COMMON(BASE):
    """Plain image-folder dataset, used only for inference.

    Lists every .png/.jpg/.jpeg file in ``img_folder`` (sorted by name) and
    yields path-only sample dicts; no annotations are involved.
    """

    def __init__(self, img_folder, **kwargs):
        super(COMMON, self).__init__(**kwargs)
        self.dataset_path = img_folder
        # keep only recognized image files, in deterministic (sorted) order
        entries = os.listdir(self.dataset_path)
        self.img_names = sorted(
            name for name in entries
            if name.endswith(('.png', '.jpg', '.jpeg'))
        )
        assert self.mode == 'infer'

    def __len__(self):
        return len(self.img_names)

    def get_raw_data(self, idx):
        """Return a minimal sample dict (image path/name only) for ``idx``."""
        wrapped = idx % len(self.img_names)
        name = self.img_names[wrapped]
        return {
            'img_path': os.path.join(self.dataset_path, name),
            'img_name': name,
            'ds': 'common',
        }
datasets/multiple_datasets.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from torch.utils.data.dataset import Dataset
3
+ import numpy as np
4
+ from .agora import AGORA
5
+ from .bedlam import BEDLAM
6
+
7
+ datasets_dict = {'bedlam': BEDLAM, 'agora': AGORA}
8
+
9
class MultipleDatasets(Dataset):
    """Concatenate several human-mesh datasets behind one Dataset interface.

    With ``make_same_len=True`` every sub-dataset is virtually padded to the
    length of the largest one (short datasets repeat, with random sampling in
    the leftover tail); otherwise items are indexed as a plain concatenation.
    """

    def __init__(self, datasets_used, datasets_split = None, make_same_len = False, **kwargs):
        if datasets_split is None:
            self.dbs = [datasets_dict[name](**kwargs) for name in datasets_used]
        else:
            self.dbs = [datasets_dict[name](part, **kwargs)
                        for name, part in zip(datasets_used, datasets_split)]

        db_lens = [len(db) for db in self.dbs]
        self.db_num = len(self.dbs)
        self.max_db_data_num = max(db_lens)
        # cumulative lengths, used to map a flat index to (db, item)
        self.db_len_cumsum = np.cumsum(db_lens)
        self.make_same_len = make_same_len
        # expose the first dataset's SMPL layer for downstream consumers
        self.human_model = self.dbs[0].human_model

    def __len__(self):
        if self.make_same_len:
            # every db is padded to the longest one
            return self.max_db_data_num * self.db_num
        # plain concatenation of all dbs
        return sum(len(db) for db in self.dbs)

    def __getitem__(self, index):
        if self.make_same_len:
            db_idx, data_idx = divmod(index, self.max_db_data_num)
            db_len = len(self.dbs[db_idx])
            full_rounds = self.max_db_data_num // db_len
            if data_idx >= db_len * full_rounds:
                # leftover tail that no full repetition covers: sample randomly
                data_idx = random.randint(0, db_len - 1)
            else:
                # whole repetitions: simple modular indexing
                data_idx = data_idx % db_len
        else:
            # first db whose cumulative length exceeds `index`
            db_idx = int(np.searchsorted(self.db_len_cumsum, index, side='right'))
            data_idx = index if db_idx == 0 else index - self.db_len_cumsum[db_idx - 1]

        norm_img, meta_data = self.dbs[db_idx][data_idx]
        return norm_img, meta_data
demo/img0.png ADDED

Git LFS Details

  • SHA256: 1473b45a1c82b64d90f01ad9f43db9719cc3858bbf3fe0dc36f0e8bee717e3fc
  • Pointer size: 132 Bytes
  • Size of remote file: 6.34 MB
demo/img1.jpeg ADDED
demo/img2.jpg ADDED

Git LFS Details

  • SHA256: d1be2c06e55a514a47d2fd6880f9cb702196b08e220997ae8571399efb7d7ab7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.39 MB
docs/fix_chumpy.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You may need to modify the `chumpy` package to avoid errors.
2
+
3
+ * Comment out line 11 in `${Your_Conda_Environment}/lib/python3.11/site-packages/chumpy/__init__.py`:
4
+ ```
5
+ from .ch import *
6
+ from .logic import *
7
+
8
+ from .optimization import minimize
9
+ from . import extras
10
+ from . import testing
11
+ from .version import version as __version__
12
+
13
+ from .version import version as __version__
14
+
15
+ # from numpy import bool, int, float, complex, object, unicode, str, nan, inf
16
+ ```
17
+ * Add *"inspect.getargspec = inspect.getfullargspec"* in `${Your_Conda_Environment}/lib/python3.11/site-packages/chumpy/ch.py` (line 25). Now it should look like:
18
+ ```
19
+ #!/usr/bin/env python
20
+ # encoding: utf-8
21
+ """
22
+ Author(s): Matthew Loper
23
+
24
+ See LICENCE.txt for licensing and contact information.
25
+ """
26
+
27
+
28
+ __all__ = ['Ch', 'depends_on', 'MatVecMult', 'ChHandle', 'ChLambda']
29
+
30
+ import os, sys, time
31
+ import inspect
32
+ import scipy.sparse as sp
33
+ import numpy as np
34
+ import numbers
35
+ import weakref
36
+ import copy as external_copy
37
+ from functools import wraps
38
+ from scipy.sparse.linalg.interface import LinearOperator
39
+ from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
40
+ import collections
41
+ from copy import deepcopy
42
+ from functools import reduce
43
+ inspect.getargspec = inspect.getfullargspec
44
+ ```
engines/__init__.py ADDED
File without changes
engines/engine.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from accelerate import Accelerator
2
+ from tqdm.auto import tqdm
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from datasets.multiple_datasets import MultipleDatasets, datasets_dict
6
+ from datasets.common import COMMON
7
+ from transformers import get_scheduler
8
+ from safetensors.torch import load_file
9
+ import os
10
+ import re
11
+ import time
12
+ import datetime
13
+ from models import build_sat_model
14
+ from .funcs.eval_funcs import *
15
+ from .funcs.infer_funcs import inference
16
+ from utils import misc
17
+ from utils.misc import get_world_size
18
+ import torch.multiprocessing
19
+ import numpy as np
20
+
21
+
22
class Engine():
    """Training / evaluation / inference driver built on HF `accelerate`.

    The constructor wires up, in order: the Accelerator, the model+criterion,
    the dataloaders, and (train mode only) the optimizer/scheduler and an
    optional checkpoint resume.  ``mode`` selects which of train()/eval()/
    infer() is expected to be called afterwards.
    """

    def __init__(self, args, mode='train'):
        self.exp_name = args.exp_name
        self.mode = mode
        assert mode in ['train','eval','infer']
        self.conf_thresh = args.conf_thresh
        # maps eval-split key -> evaluation routine (see engines/funcs/eval_funcs.py)
        self.eval_func_maps = {'agora_validation': evaluate_agora,
                               'bedlam_validation_6fps': evaluate_agora,
                               'agora_test': test_agora}
        self.inference_func = inference

        if self.mode == 'train':
            self.output_dir = os.path.join('./outputs')
            self.log_dir = os.path.join(self.output_dir,'logs')
            self.ckpt_dir = os.path.join(self.output_dir,'ckpts')
            self.distributed_eval = args.distributed_eval
            self.eval_vis_num = args.eval_vis_num
        elif self.mode == 'eval':
            self.output_dir = os.path.join('./results')
            self.distributed_eval = args.distributed_eval
            self.eval_vis_num = args.eval_vis_num
        elif self.mode == 'infer':
            # explicit output dir wins; otherwise a timestamped folder is created
            output_dir = getattr(args, 'output_dir', None)
            if output_dir is not None:
                self.output_dir = output_dir
            else:
                now = datetime.datetime.now()
                timestamp = now.strftime("%Y%m%d_%H%M%S")
                self.output_dir = os.path.join('./results',f'{self.exp_name}_infer_{timestamp}')
            self.distributed_infer = args.distributed_infer

        self.prepare_accelerator()
        self.prepare_models(args)
        self.prepare_datas(args)
        if self.mode == 'train':
            self.prepare_training(args)

        # parameter count summary, printed once (accelerator.print is main-process only)
        total_cnt = sum(p.numel() for p in self.model.parameters())
        trainable_cnt = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.accelerator.print(f'Initialization finished.\n{trainable_cnt} trainable parameters({total_cnt} total).')

    def prepare_accelerator(self):
        """Create the Accelerator; in train mode also set up tensorboard tracking."""
        if self.mode == 'train':
            self.accelerator = Accelerator(
                log_with="tensorboard",
                project_dir=os.path.join(self.log_dir)
            )
            if self.accelerator.is_main_process:
                os.makedirs(self.log_dir, exist_ok=True)
                os.makedirs(os.path.join(self.ckpt_dir,self.exp_name),exist_ok=True)
                self.accelerator.init_trackers(self.exp_name)
        else:
            self.accelerator = Accelerator()
            if self.accelerator.is_main_process:
                os.makedirs(self.output_dir, exist_ok=True)

    def prepare_models(self, args):
        # load model and criterion (criterion only exists in train mode)
        self.accelerator.print('Preparing models...')
        self.unwrapped_model, self.criterion = build_sat_model(args, set_criterion = (self.mode == 'train'))
        if self.criterion is not None:
            self.weight_dict = self.criterion.weight_dict
        # load weights (strict=False: allows partial/renamed checkpoints)
        if args.pretrain:
            self.accelerator.print(f'Loading pretrained weights: {args.pretrain_path}')
            state_dict = torch.load(args.pretrain_path)
            self.unwrapped_model.load_state_dict(state_dict,strict=False)

        # to gpu (accelerate wraps the model; keep the unwrapped handle for eval/infer)
        self.model = self.accelerator.prepare(self.unwrapped_model)

    def prepare_datas(self, args):
        # load dataset and dataloader for the current mode
        if self.mode == 'train':
            self.accelerator.print('Loading training datasets:\n',
                                   [f'{d}_{s}' for d,s in zip(args.train_datasets_used, args.train_datasets_split)])
            self.train_batch_size = args.train_batch_size
            train_dataset = MultipleDatasets(args.train_datasets_used, args.train_datasets_split,
                                             make_same_len=False, input_size=args.input_size, aug=True,
                                             mode = 'train', sat_cfg=args.sat_cfg,
                                             aug_cfg=args.aug_cfg)
            self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.train_batch_size,
                                               shuffle=True,collate_fn=misc.collate_fn,
                                               num_workers=args.train_num_workers,pin_memory=True)
            self.train_dataloader = self.accelerator.prepare(self.train_dataloader)

        if self.mode != 'infer':
            # evaluation dataloaders are built for both train and eval modes
            self.accelerator.print('Loading evaluation datasets:',
                                   [f'{d}_{s}' for d,s in zip(args.eval_datasets_used, args.eval_datasets_split)])
            self.eval_batch_size = args.eval_batch_size
            eval_ds = {f'{ds}_{split}': datasets_dict[ds](split = split,
                                                          mode = 'eval',
                                                          input_size = args.input_size,
                                                          aug = False,
                                                          sat_cfg=args.sat_cfg)
                       for (ds, split) in zip(args.eval_datasets_used, args.eval_datasets_split)}
            self.eval_dataloaders = {k: DataLoader(dataset=v, batch_size=self.eval_batch_size,
                                                   shuffle=False,collate_fn=misc.collate_fn,
                                                   num_workers=args.eval_num_workers,pin_memory=True)
                                     for (k,v) in eval_ds.items()}
            if self.distributed_eval:
                # shard each eval dataloader across processes
                for (k,v) in self.eval_dataloaders.items():
                    self.eval_dataloaders.update({k: self.accelerator.prepare(v)})

        else:
            # inference: a flat image folder wrapped by the COMMON dataset
            img_folder = args.input_dir
            self.accelerator.print(f'Loading inference images from {img_folder}')
            self.infer_batch_size = args.infer_batch_size
            infer_ds = COMMON(img_folder = img_folder, input_size=args.input_size,aug=False,
                              mode = 'infer', sat_cfg=args.sat_cfg)
            self.infer_dataloader = DataLoader(dataset=infer_ds, batch_size=self.infer_batch_size,
                                               shuffle=False,collate_fn=misc.collate_fn,
                                               num_workers=args.infer_num_workers,pin_memory=True)

            if self.distributed_infer:
                self.infer_dataloader = self.accelerator.prepare(self.infer_dataloader)

    def prepare_training(self, args):
        """Set up optimizer, lr scheduler, epoch bookkeeping and checkpoint resume."""
        self.start_epoch = 0
        self.num_epochs = args.num_epochs
        self.global_step = 0
        if hasattr(args, 'sat_gt_epoch'):
            # curriculum: feed ground-truth SAT maps for the first N epochs
            self.sat_gt_epoch = args.sat_gt_epoch
            self.accelerator.print(f'Use GT for the first {self.sat_gt_epoch} epoch(s)...')
        else:
            self.sat_gt_epoch = -1
        self.save_and_eval_epoch = args.save_and_eval_epoch
        self.least_eval_epoch = args.least_eval_epoch

        self.detach_j3ds = args.detach_j3ds

        self.accelerator.print('Preparing optimizer and lr_scheduler...')
        # two parameter groups: encoder params get their own (usually smaller) lr
        param_dicts = [
            {
                "params":
                    [p for n, p in self.unwrapped_model.named_parameters()
                     if not misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
                "lr": args.lr,
            },
            {
                "params":
                    [p for n, p in self.unwrapped_model.named_parameters()
                     if misc.match_name_keywords(n, args.lr_encoder_names) and p.requires_grad],
                "lr": args.lr_encoder,
            }
        ]

        # optimizer
        if args.optimizer == 'adamw':
            self.optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                               weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        # lr_scheduler (stepped per iteration in train(), hence steps * world_size)
        if args.lr_scheduler == 'cosine':
            self.lr_scheduler = get_scheduler(name="cosine", optimizer=self.optimizer,
                                              num_warmup_steps=args.num_warmup_steps,
                                              num_training_steps=get_world_size() * self.num_epochs * len(self.train_dataloader))
        elif args.lr_scheduler == 'multistep':
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, args.milestones, gamma=args.gamma)
        else:
            raise NotImplementedError

        self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler)

        # resume
        if args.resume: #load model, optimizer, lr_scheduler and random_state
            if hasattr(args, 'ckpt_epoch'):
                self.load_ckpt(args.ckpt_epoch,args.ckpt_step)
            else:
                # auto-resume: pick the checkpoint folder with the highest epoch number
                self.accelerator.print('Auto resume from latest ckpt...')
                epoch, step = -1, -1
                pattern = re.compile(r'epoch_(\d+)_step_(\d+)')
                for folder_name in os.listdir(os.path.join(self.output_dir,'ckpts',self.exp_name)):
                    match = pattern.match(folder_name)
                    if match:
                        i, j = int(match.group(1)), int(match.group(2))
                        if i > epoch:
                            epoch, step = i, j
                if epoch >= 0:
                    self.load_ckpt(epoch, step)
                else:
                    self.accelerator.print('No existing ckpts! Train from scratch.')

    def load_ckpt(self, epoch, step):
        """Restore full accelerate state (model/optim/scheduler/RNG) from a checkpoint dir."""
        self.accelerator.print(f'Loading checkpoint: epoch_{epoch}_step_{step}')
        ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{step}')
        self.start_epoch = epoch + 1
        self.global_step = step + 1
        self.accelerator.load_state(ckpts_save_path)

    def train(self):
        """Main training loop: per-step optimize + log, periodic save/eval per epoch."""
        # torch.autograd.set_detect_anomaly(True)
        self.accelerator.print('Start training!')
        for epoch in range(self.start_epoch, self.num_epochs):
            torch.cuda.empty_cache()
            progress_bar = tqdm(total=len(self.train_dataloader), disable=not self.accelerator.is_local_main_process)
            progress_bar.set_description(f"Epoch {epoch}")

            self.model.train()
            self.criterion.train()

            # curriculum flag: use GT SAT maps during warm-up epochs
            sat_use_gt = (epoch < self.sat_gt_epoch)

            for step, (samples,targets) in enumerate(self.train_dataloader):

                outputs = self.model(samples, targets, sat_use_gt = sat_use_gt, detach_j3ds = self.detach_j3ds)
                loss_dict = self.criterion(outputs, targets)

                # weighted sum over all criterion terms
                loss = sum(loss_dict[k] * self.weight_dict[k] for k in loss_dict.keys())

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()

                # scheduler is stepped per iteration (cosine schedule expects this)
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                # average losses across processes for logging; drop per-layer ('.') keys
                reduced_dict = self.accelerator.reduce(loss_dict,reduction='mean')
                simplified_logs = {k: v.item() for k, v in reduced_dict.items() if '.' not in k}

                # logs.update({"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})
                if self.accelerator.is_main_process:
                    tqdm.write(f'[{epoch}-{step+1}/{len(self.train_dataloader)}]: ' + str(simplified_logs))

                if step % 10 == 0:
                    self.accelerator.log({('train/'+k):v for k,v in simplified_logs.items()},
                                         step=self.global_step)

                progress_bar.update(1)
                progress_bar.set_postfix(**{"lr": self.lr_scheduler.get_last_lr()[0], "step": self.global_step})

                self.global_step += 1
            self.accelerator.wait_for_everyone()

            # self.lr_scheduler.step()

            if epoch % self.save_and_eval_epoch == 0 or epoch == self.num_epochs-1:
                self.save_and_eval(epoch, save_ckpt=True)

        self.accelerator.end_training()

    def eval(self, results_save_path = None, epoch = -1):
        """Run every configured eval split at every confidence threshold.

        In train mode the returned error dicts are also pushed to the tracker.
        """
        if results_save_path is None:
            results_save_path = os.path.join(self.output_dir,self.exp_name,'evaluation')
        # preparing
        self.model.eval()
        unwrapped_model = self.unwrapped_model # self.accelerator.unwrap_model(self.model)
        if self.accelerator.is_main_process:
            os.makedirs(results_save_path,exist_ok=True)
        # evaluate
        for i, (key, eval_dataloader) in enumerate(self.eval_dataloaders.items()):
            assert key in self.eval_func_maps
            img_cnt = len(eval_dataloader) * self.eval_batch_size
            if self.distributed_eval:
                img_cnt *= self.accelerator.num_processes
            self.accelerator.print(f'Evaluate on {key}: {img_cnt} images')
            self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
            # non-agora/bedlam splits are evaluated at a single fixed threshold
            conf_thresh = self.conf_thresh if 'agora' in key or 'bedlam' in key else [0.2]
            for thresh in conf_thresh:
                if self.accelerator.is_main_process or self.distributed_eval:
                    # NOTE(review): vis_step = img_cnt // eval_vis_num — presumably
                    # eval_vis_num <= img_cnt, otherwise vis_step is 0; confirm upstream.
                    error_dict = self.eval_func_maps[key](model = unwrapped_model,
                                                          eval_dataloader = eval_dataloader,
                                                          conf_thresh = thresh,
                                                          vis_step = img_cnt // self.eval_vis_num,
                                                          results_save_path = os.path.join(results_save_path,key,f'thresh_{thresh}'),
                                                          distributed = self.distributed_eval,
                                                          accelerator = self.accelerator,
                                                          vis=True)
                    # eval funcs may return a failure string instead of a dict
                    if isinstance(error_dict,dict) and self.mode == 'train':
                        log_dict = flatten_dict(error_dict)
                        self.accelerator.log({(f'{key}_thresh_{thresh}/'+k):v for k,v in log_dict.items()}, step=epoch)

                    self.accelerator.print(f'thresh_{thresh}: ',error_dict)
                self.accelerator.wait_for_everyone()

    def save_and_eval(self, epoch, save_ckpt=False):
        """Checkpoint the full accelerate state, then run eval() past the warm-up epochs."""
        torch.cuda.empty_cache()
        # save current state and model
        if self.accelerator.is_main_process and save_ckpt:
            ckpts_save_path = os.path.join(self.output_dir,'ckpts',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
            os.makedirs(ckpts_save_path,exist_ok=True)
            self.accelerator.save_state(ckpts_save_path, safe_serialization=False)
        self.accelerator.wait_for_everyone()

        if epoch < self.least_eval_epoch:
            return
        results_save_path = os.path.join(self.output_dir,'results',self.exp_name, f'epoch_{epoch}_step_{self.global_step-1}')
        self.eval(results_save_path, epoch=epoch)

    def infer(self):
        """Run inference over the COMMON image-folder dataloader for each threshold."""
        self.model.eval()
        # unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model = self.unwrapped_model

        results_save_path = self.output_dir
        if self.accelerator.is_main_process:
            os.makedirs(results_save_path,exist_ok=True)

        self.accelerator.print('Using following threshold(s): ', self.conf_thresh)
        for thresh in self.conf_thresh:
            if self.accelerator.is_main_process or self.distributed_infer:
                self.inference_func(model = unwrapped_model,
                                    infer_dataloader = self.infer_dataloader,
                                    conf_thresh = thresh,
                                    results_save_path = os.path.join(results_save_path,f'thresh_{thresh}'),
                                    distributed = self.distributed_infer,
                                    accelerator = self.accelerator)
            self.accelerator.wait_for_everyone()
336
+
337
+
338
def flatten_dict(d, parent_key='', sep='-'):
    """Flatten a nested dict into one level, joining key paths with *sep*.

    Example: {'a': {'b': 1}} -> {'a-b': 1}.
    """
    flat = {}
    for key, value in d.items():
        compound_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # recurse into nested dicts, carrying the compound prefix
            flat.update(flatten_dict(value, compound_key, sep=sep))
        else:
            flat[compound_key] = value
    return flat
347
+
engines/funcs/__init__.py ADDED
File without changes
engines/funcs/eval_funcs.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm.auto import tqdm
3
+ import torch
4
+ import numpy as np
5
+ from utils.evaluation import cal_3d_position_error, match_2d_greedy, get_matching_dict, compute_prf1, vectorize_distance, calculate_iou
6
+ from utils.transforms import pelvis_align, root_align, unNormalize
7
+ from utils.visualization import tensor_to_BGR, pad_img
8
+ from utils.visualization import vis_meshes_img, vis_boxes, vis_sat, vis_scale_img, get_colors_rgb
9
+ from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
10
+ from utils.constants import human36_eval_joint, J24_TO_H36M, H36M_TO_MPII
11
+ import time
12
+ import datetime
13
+ import scipy.io as sio
14
+ import cv2
15
+ import zipfile
16
+ import pickle
17
+
18
# for agora evaluation
def select_and_align(smpl_joints, smpl_verts, body_verts_ind):
    """Select the first 24 SMPL joints and the body vertices, pelvis-align both.

    The vertices are aligned using the *raw* joints first, so both returned
    arrays live in the same pelvis-centred frame.
    """
    joints = smpl_joints[:24, :]
    verts = smpl_verts[body_verts_ind, :]
    assert len(verts.shape) == 2
    # order matters: verts must be aligned with the not-yet-aligned joints
    aligned_verts = pelvis_align(joints, verts)
    aligned_joints = pelvis_align(joints)
    return aligned_joints, aligned_verts
26
+
27
+
28
# Modified from agora_evaluation
def evaluate_agora(model, eval_dataloader, conf_thresh,
                   vis = True, vis_step = 40, results_save_path = None,
                   distributed = False, accelerator = None):
    """Evaluate multi-person SMPL predictions AGORA-style.

    Greedy-matches predicted people to GT via 2D joints, then accumulates
    pelvis-aligned MPJPE/MVE plus detection precision/recall/F1 (and the
    kid-only variants when the split provides 'kid' labels).  Returns an
    error dict, or a failure string when too few matches were collected.
    """
    assert results_save_path is not None
    assert accelerator is not None
    num_processes = accelerator.num_processes

    # kid metrics only exist for AGORA train-split annotations
    has_kid = ('train' in eval_dataloader.dataset.split and eval_dataloader.dataset.ds_name == 'agora')

    os.makedirs(results_save_path,exist_ok=True)
    if vis:
        imgs_save_dir = os.path.join(results_save_path, 'imgs')
        os.makedirs(imgs_save_dir, exist_ok = True)

    step = 0
    total_miss_count = 0
    total_count = 0
    total_fp = 0
    # seeded with one 0. per list so gather_for_metrics never sees an empty list;
    # the seed values are compensated for in the final averaging
    mve, mpjpe = [0.], [0.]

    if has_kid:
        kid_total_miss_count = 0
        kid_total_count = 0
        kid_mve, kid_mpjpe = [0.], [0.]

    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model
    body_verts_ind = smpl_layer.body_vertex_idx

    progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('evaluate')
    for itr, (samples, targets) in enumerate(eval_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            #gt (first 24 joints only)
            gt_j2ds = targets[idx]['j2ds'].cpu().numpy()[:,:24,:]
            gt_j3ds = targets[idx]['j3ds'].cpu().numpy()[:,:24,:]
            gt_verts = targets[idx]['verts'].cpu().numpy()

            #pred: keep only queries above the confidence threshold
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_j2ds = outputs['pred_j2ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
            pred_j3ds = outputs['pred_j3ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
            pred_verts = outputs['pred_verts'][idx][select_queries_idx].detach().cpu().numpy()

            matched_verts_idx = []
            assert len(gt_j2ds.shape) == 3 and len(pred_j2ds.shape) == 3
            #matching
            greedy_match = match_2d_greedy(pred_j2ds, gt_j2ds) # tuples are (idx_pred_kps, idx_gt_kps)
            matchDict, falsePositive_count = get_matching_dict(greedy_match)

            #align with matching result: build per-GT lists of matched predictions
            gt_verts_list, pred_verts_list, gt_joints_list, pred_joints_list = [], [], [], []
            gtIdxs = np.arange(len(gt_j3ds))
            miss_flag = []
            for gtIdx in gtIdxs:
                gt_verts_list.append(gt_verts[gtIdx])
                gt_joints_list.append(gt_j3ds[gtIdx])
                if matchDict[str(gtIdx)] == 'miss' or matchDict[str(gtIdx)] == 'invalid':
                    # unmatched GT person: counts as a miss, placeholder entries
                    miss_flag.append(1)
                    pred_verts_list.append([])
                    pred_joints_list.append([])
                else:
                    miss_flag.append(0)
                    pred_joints_list.append(pred_j3ds[matchDict[str(gtIdx)]])
                    pred_verts_list.append(pred_verts[matchDict[str(gtIdx)]])
                    matched_verts_idx.append(matchDict[str(gtIdx)])

            if has_kid:
                gt_kid_list = targets[idx]['kid']

            #calculating 3d errors
            for i, (gt3d, pred) in enumerate(zip(gt_joints_list, pred_joints_list)):
                total_count += 1
                if has_kid and gt_kid_list[i]:
                    kid_total_count += 1

                # Get corresponding ground truth and predicted 3d joints and verts
                if miss_flag[i] == 1:
                    total_miss_count += 1
                    if has_kid and gt_kid_list[i]:
                        kid_total_miss_count += 1
                    continue

                gt3d = gt3d.reshape(-1, 3)
                pred3d = pred.reshape(-1, 3)
                gt3d_verts = gt_verts_list[i].reshape(-1, 3)
                pred3d_verts = pred_verts_list[i].reshape(-1, 3)

                # pelvis-align both GT and prediction before computing errors
                gt3d, gt3d_verts = select_and_align(gt3d, gt3d_verts, body_verts_ind)
                pred3d, pred3d_verts = select_and_align(pred3d, pred3d_verts, body_verts_ind)

                #joints (PA error computed but unused here)
                error_j, pa_error_j = cal_3d_position_error(pred3d, gt3d)
                mpjpe.append(error_j)
                if has_kid and gt_kid_list[i]:
                    kid_mpjpe.append(error_j)
                #vertices
                error_v,pa_error_v = cal_3d_position_error(pred3d_verts, gt3d_verts)
                mve.append(error_v)
                if has_kid and gt_kid_list[i]:
                    kid_mve.append(error_v)

            #counting
            step += 1
            total_fp += falsePositive_count

            # global image index, unique across processes
            img_idx = step + accelerator.process_index*len(eval_dataloader)*bs

            if vis and (img_idx%vis_step == 0):
                img_name = targets[idx]['img_path'].split('/')[-1].split('.')[0]
                ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())

                # render mesh: GT in off-white
                colors = [(1.0, 1.0, 0.9)] * len(gt_verts)
                gt_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                             verts = gt_verts,
                                             smpl_faces = smpl_layer.faces,
                                             cam_intrinsics = targets[idx]['cam_intrinsics'].reshape(3,3).detach().cpu(),
                                             colors = colors)

                # predictions: red for unmatched, green for matched
                colors = [(1.0, 0.6, 0.6)] * len(pred_verts)
                for i in matched_verts_idx:
                    colors[i] = (0.7, 1.0, 0.4)

                # colors = get_colors_rgb(len(pred_verts))
                pred_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                              verts = pred_verts,
                                              smpl_faces = smpl_layer.faces,
                                              cam_intrinsics = outputs['pred_intrinsics'][idx].reshape(3,3).detach().cpu(),
                                              colors = colors,
                                              )

                if 'enc_outputs' not in outputs:
                    pred_scale_img = np.zeros_like(pred_mesh_img)
                else:
                    # rebuild the dense (h, w, 2) scale map from the flattened encoder output
                    enc_out = outputs['enc_outputs']
                    h, w = enc_out['hw'][idx]
                    flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                    ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                    xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                    scale_map = torch.zeros((h,w,2))
                    scale_map[ys,xs] = flatten_map

                    pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                                   scale_map = scale_map,
                                                   conf_thresh = model.sat_cfg['conf_thresh'],
                                                   patch_size=28)

                # boxes are predicted in normalized cxcywh; rescale to pixels
                pred_boxes = outputs['pred_boxes'][idx][select_queries_idx].detach().cpu()
                pred_boxes = box_cxcywh_to_xyxy(pred_boxes) * model.input_size
                pred_box_img = vis_boxes(ori_img.copy(), pred_boxes, color = (255,0,255))

                # sat
                sat_img = vis_sat(ori_img.copy(),
                                  input_size = model.input_size,
                                  patch_size = 14,
                                  sat_dict = outputs['sat'],
                                  bid = idx)

                ori_img = pad_img(ori_img, model.input_size)

                # 3x2 grid: input/sat, scale/boxes, GT mesh/pred mesh
                full_img = np.vstack([np.hstack([ori_img, sat_img]),
                                      np.hstack([pred_scale_img, pred_box_img]),
                                      np.hstack([gt_mesh_img, pred_mesh_img])])

                cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.png'), full_img)

        progress_bar.update(1)

    if distributed:
        # collect per-process metric lists/counters onto every process
        mve = accelerator.gather_for_metrics(mve)
        mpjpe = accelerator.gather_for_metrics(mpjpe)

        total_miss_count = sum(accelerator.gather_for_metrics([total_miss_count]))
        total_count = sum(accelerator.gather_for_metrics([total_count]))
        total_fp = sum(accelerator.gather_for_metrics([total_fp]))

        if has_kid:
            kid_mve = accelerator.gather_for_metrics(kid_mve)
            kid_mpjpe = accelerator.gather_for_metrics(kid_mpjpe)
            kid_total_miss_count = sum(accelerator.gather_for_metrics([kid_total_miss_count]))
            kid_total_count = sum(accelerator.gather_for_metrics([kid_total_count]))

    # only the seed zeros present => nothing was matched; signal failure
    if len(mpjpe) <= num_processes:
        return "Failed to evaluate. Keep training!"
    if has_kid and len(kid_mpjpe) <= num_processes:
        return "Failed to evaluate. Keep training!"

    precision, recall, f1 = compute_prf1(total_count,total_miss_count,total_fp)
    error_dict = {}
    error_dict['precision'] = precision
    error_dict['recall'] = recall
    error_dict['f1'] = f1

    # subtract the per-process seed 0. entries from the denominator
    error_dict['MPJPE'] = round(sum(mpjpe)/(len(mpjpe)-num_processes), 1)
    error_dict['NMJE'] = round(error_dict['MPJPE'] / (f1), 1)
    error_dict['MVE'] = round(sum(mve)/(len(mve)-num_processes), 1)
    error_dict['NMVE'] = round(error_dict['MVE'] / (f1), 1)

    if has_kid:
        kid_precision, kid_recall, kid_f1 = compute_prf1(kid_total_count,kid_total_miss_count,total_fp)
        error_dict['kid_precision'] = kid_precision
        error_dict['kid_recall'] = kid_recall
        error_dict['kid_f1'] = kid_f1

        error_dict['kid-MPJPE'] = round(sum(kid_mpjpe)/(len(kid_mpjpe)-num_processes), 1)
        error_dict['kid-NMJE'] = round(error_dict['kid-MPJPE'] / (kid_f1), 1)
        error_dict['kid-MVE'] = round(sum(kid_mve)/(len(kid_mve)-num_processes), 1)
        error_dict['kid-NMVE'] = round(error_dict['kid-MVE'] / (kid_f1), 1)

    if accelerator.is_main_process:
        with open(os.path.join(results_save_path,'results.txt'),'w') as f:
            for k,v in error_dict.items():
                f.write(f'{k}: {v}\n')

    return error_dict
256
+
257
+
258
def test_agora(model, eval_dataloader, conf_thresh,
               vis = True, vis_step = 400, results_save_path = None,
               distributed = False, accelerator = None):
    """Generate AGORA test-set submission files (no GT available).

    Writes one pickle per detected person in the AGORA submission format,
    optionally renders visualizations, and zips the prediction folder.
    """
    assert results_save_path is not None
    assert accelerator is not None

    os.makedirs(os.path.join(results_save_path,'predictions'),exist_ok=True)
    if vis:
        imgs_save_dir = os.path.join(results_save_path, 'imgs')
        os.makedirs(imgs_save_dir, exist_ok = True)
    step = 0
    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model

    progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('testing')
    for itr, (samples, targets) in enumerate(eval_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            #gt
            img_name = targets[idx]['img_name'].split('.')[0]
            #pred: keep confident queries; 2D joints rescaled from network input
            # resolution back to the original 3840-wide AGORA frames
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_j2ds = np.array(outputs['pred_j2ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]*(3840/model.input_size)
            pred_j3ds = np.array(outputs['pred_j3ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]
            pred_verts = np.array(outputs['pred_verts'][idx][select_queries_idx].detach().to('cpu'))
            pred_poses = np.array(outputs['pred_poses'][idx][select_queries_idx].detach().to('cpu'))
            pred_betas = np.array(outputs['pred_betas'][idx][select_queries_idx].detach().to('cpu'))

            #visualization
            step+=1
            # global image index, unique across processes
            img_idx = step + accelerator.process_index*len(eval_dataloader)*bs
            if vis and (img_idx%vis_step == 0):
                ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
                ori_img = pad_img(ori_img, model.input_size)

                sat_img = vis_sat(ori_img.copy(),
                                  input_size = model.input_size,
                                  patch_size = 14,
                                  sat_dict = outputs['sat'],
                                  bid = idx)

                colors = get_colors_rgb(len(pred_verts))
                mesh_img = vis_meshes_img(img = ori_img.copy(),
                                          verts = pred_verts,
                                          smpl_faces = smpl_layer.faces,
                                          colors = colors,
                                          cam_intrinsics = outputs['pred_intrinsics'][idx].detach().cpu())

                if 'enc_outputs' not in outputs:
                    pred_scale_img = np.zeros_like(ori_img)
                else:
                    # rebuild the dense (h, w, 2) scale map from the flattened encoder output
                    enc_out = outputs['enc_outputs']
                    h, w = enc_out['hw'][idx]
                    flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                    ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                    xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                    scale_map = torch.zeros((h,w,2))
                    scale_map[ys,xs] = flatten_map
                    pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                                   scale_map = scale_map,
                                                   conf_thresh = model.sat_cfg['conf_thresh'],
                                                   patch_size=28)

                # 2x2 grid: input/meshes, scale map/sat
                full_img = np.vstack([np.hstack([ori_img, mesh_img]),
                                      np.hstack([pred_scale_img, sat_img])])
                cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.jpg'), full_img)


            # submit: one pickle per detected person, AGORA submission layout
            for pnum in range(len(pred_j2ds)):
                smpl_dict = {}
                # smpl_dict['age'] = 'kid'
                smpl_dict['joints'] = pred_j2ds[pnum].reshape(24,2)
                smpl_dict['params'] = {'transl': np.zeros((1,3)),
                                       'betas': pred_betas[pnum].reshape(1,10),
                                       'global_orient': pred_poses[pnum][:3].reshape(1,1,3),
                                       'body_pose': pred_poses[pnum][3:].reshape(1,23,3)}
                # smpl_dict['verts'] = pred_verts[pnum].reshape(6890,3)
                # smpl_dict['allSmplJoints3d'] = pred_j3ds[pnum].reshape(24,3)
                with open(os.path.join(results_save_path,'predictions',f'{img_name}_personId_{pnum}.pkl'), 'wb') as f:
                    pickle.dump(smpl_dict, f)

        progress_bar.update(1)

    accelerator.print('Packing...')

    # zip the prediction folder for upload; arcnames keep 'predictions/' prefix
    folder_path = os.path.join(results_save_path,'predictions')
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(results_save_path,f'pred_{timestamp}.zip')
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, os.path.dirname(folder_path))
                zipf.write(file_path, arcname)


    return 'Results saved at: ' + os.path.join(results_save_path,'predictions')
362
+
engines/funcs/infer_funcs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from tqdm.auto import tqdm
4
+ import torch
5
+ import numpy as np
6
+ from utils.transforms import unNormalize
7
+ from utils.visualization import tensor_to_BGR, pad_img
8
+ from utils.visualization import vis_meshes_img, vis_boxes, vis_sat, vis_scale_img, get_colors_rgb
9
+ from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
10
+ import time
11
+ import cv2
12
+ import trimesh
13
+
14
def inference(model, infer_dataloader, conf_thresh, results_save_path = None,
              distributed = False, accelerator = None):
    """Run the model over ``infer_dataloader`` and save one visualization per image.

    For each image a 2x2 grid is written to ``results_save_path`` as
    ``<img_name>.png``: top row = [predicted boxes | rendered meshes],
    bottom row = [predicted scale map | SAT map].

    Args:
        model: SAT-HMR model; ``model.human_model`` (SMPL layer),
            ``model.input_size`` and ``model.sat_cfg['conf_thresh']`` are read here.
        infer_dataloader: yields ``(samples, targets)``; each ``targets[i]`` must
            provide ``'img_size'`` (H, W tensor) and ``'img_path'``.
        conf_thresh: confidence threshold applied to ``outputs['pred_confs']``
            to select which queries are visualized.
        results_save_path: output directory (created if missing); must not be None.
        distributed: unused in this function — presumably kept so the signature
            matches the other engine entry points (TODO confirm).
        accelerator: HuggingFace Accelerate object; used for printing and to show
            the progress bar only on the local main process. Must not be None.
    """
    assert results_save_path is not None
    assert accelerator is not None

    accelerator.print(f'Results will be saved at: {results_save_path}')
    os.makedirs(results_save_path,exist_ok=True)
    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model

    progress_bar = tqdm(total=len(infer_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('inference')

    for itr, (samples, targets) in enumerate(infer_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            img_size = targets[idx]['img_size'].detach().cpu().int().numpy()
            # Image basename without extension, used for the output filename.
            img_name = targets[idx]['img_path'].split('/')[-1].split('.')[0]

            #pred
            # Keep only queries whose confidence clears the threshold.
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_verts = outputs['pred_verts'][idx][select_queries_idx].detach().cpu().numpy()

            # Undo normalization and convert the tensor back to a BGR image,
            # then whiten everything beyond the true image extent (the sample
            # was presumably padded to a square model input — TODO confirm).
            ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
            ori_img[img_size[0]:,:,:] = 255
            ori_img[:,img_size[1]:,:] = 255
            ori_img[img_size[0]:,img_size[1]:,:] = 255
            ori_img = pad_img(ori_img, model.input_size, pad_color_offset=255)

            # SAT (scale-adaptive token) map overlay, cropped to the valid region.
            sat_img = vis_sat(ori_img.copy(),
                              input_size = model.input_size,
                              patch_size = 14,
                              sat_dict = outputs['sat'],
                              bid = idx)[:img_size[0],:img_size[1]]

            # Render the predicted SMPL meshes with per-person colors.
            colors = get_colors_rgb(len(pred_verts))
            pred_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                           verts = pred_verts,
                                           smpl_faces = smpl_layer.faces,
                                           cam_intrinsics = outputs['pred_intrinsics'][idx].reshape(3,3).detach().cpu(),
                                           colors=colors)[:img_size[0],:img_size[1]]

            if 'enc_outputs' not in outputs:
                # No encoder outputs: use a blank placeholder panel.
                pred_scale_img = np.zeros_like(ori_img)[:img_size[0],:img_size[1]]
            else:
                # Scatter the flattened per-token (conf, scale) predictions back
                # onto the (h, w) token grid using the stored token coordinates.
                enc_out = outputs['enc_outputs']
                h, w = enc_out['hw'][idx]
                flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                scale_map = torch.zeros((h,w,2))
                scale_map[ys,xs] = flatten_map

                pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                               scale_map = scale_map,
                                               conf_thresh = model.sat_cfg['conf_thresh'],
                                               patch_size=28)[:img_size[0],:img_size[1]]

            # Boxes are predicted as normalized (cx, cy, w, h); convert to
            # absolute xyxy in model-input pixels for drawing.
            pred_boxes = outputs['pred_boxes'][idx][select_queries_idx].detach().cpu()
            pred_boxes = box_cxcywh_to_xyxy(pred_boxes) * model.input_size
            pred_box_img = vis_boxes(ori_img.copy(), pred_boxes, color = (255,0,255))[:img_size[0],:img_size[1]]

            # Assemble the 2x2 grid and write it out.
            cv2.imwrite(os.path.join(results_save_path, f'{img_name}.png'), np.vstack([np.hstack([pred_box_img, pred_mesh_img]),
                                                                                       np.hstack([pred_scale_img, sat_img])]))


        progress_bar.update(1)
+
figures/pipeline.png ADDED

Git LFS Details

  • SHA256: 1256facd5fe87da77b173205e5a4466081b5723aa52ee76736b89c51d6928153
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
figures/qualitative_results.png ADDED

Git LFS Details

  • SHA256: 6139ee8ad610f83eaffabc4917e4a864180e298e8ccf438edaed8f9b117327e5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
figures/results.png ADDED

Git LFS Details

  • SHA256: 84a5282d996b7fabe3dd8a6ba09d23ddbf065fac414ebf8af6f61e8f16e68a45
  • Pointer size: 132 Bytes
  • Size of remote file: 1.82 MB
figures/results_3d.gif ADDED

Git LFS Details

  • SHA256: f7734e0a6f37aaabf03888b7d79d1b43b53c5b6a9cd7c1f1fe042079d6c9823d
  • Pointer size: 132 Bytes
  • Size of remote file: 7.09 MB
main.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import yaml
4
+ import numpy as np
5
+ from engines.engine import Engine
6
+
7
+
8
def get_args_parser():
    """Build the base SAT-HMR argument parser.

    ``add_help=False`` because this parser is used as a *parent* parser by the
    entry point in ``__main__``.

    Returns:
        argparse.ArgumentParser with ``--cfg`` (run-config name, default None)
        and ``--mode`` (one of train/eval/infer, default 'train').
    """
    parser = argparse.ArgumentParser('SAT-HMR', add_help=False)
    for flag, default in (('--cfg', None), ('--mode', 'train')):
        parser.add_argument(flag, default=default, type=str)
    return parser
14
+
15
def update_args(args, cfg_path):
    """Overlay the key/value pairs of a YAML file onto an argparse Namespace.

    YAML values take precedence over existing attributes. Note that
    ``vars(args)`` returns the namespace's live ``__dict__``, so the input
    namespace is updated in place as well; a fresh Namespace carrying the
    merged values is returned.

    Args:
        args: argparse.Namespace to extend/override.
        cfg_path: path to a YAML config file.

    Returns:
        argparse.Namespace with the merged attributes.
    """
    with open(cfg_path) as cfg_file:
        overrides = yaml.safe_load(cfg_file)
    merged = vars(args)
    merged.update(overrides)
    return argparse.Namespace(**merged)
22
+
23
+
24
if __name__ == '__main__':
    # Compose the CLI from the base parser, then layer two YAML configs onto
    # the parsed args: first the run config (configs/run/<cfg>.yaml), then the
    # model config (configs/models/<args.model>.yaml). The run config is
    # expected to define 'model' — TODO confirm against configs/run/demo.yaml.
    parser = argparse.ArgumentParser('SAT-HMR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    assert args.cfg is not None
    args = update_args(args, os.path.join('configs', 'run', f'{args.cfg}.yaml'))
    args.exp_name = args.cfg
    args = update_args(args, os.path.join('configs', 'models', f'{args.model}.yaml'))


    if args.mode.lower() == 'train':
        # NOTE: training is disabled in this release — the raise makes the
        # remaining lines of this branch unreachable dead code.
        raise NotImplementedError
        from accelerate.utils import set_seed
        seed = args.seed  # NOTE(review): unused local; set_seed below reads args.seed directly
        set_seed(args.seed)
        engine = Engine(args, mode='train')
        engine.train()

    elif args.mode.lower() == 'eval':
        # NOTE: evaluation is likewise disabled; code below the raise never runs.
        raise NotImplementedError
        engine = Engine(args, mode='eval')
        engine.eval()

    elif args.mode.lower() == 'infer':
        # Inference is the only supported mode in this release.
        engine = Engine(args, mode='infer')
        engine.infer()

    else:
        print('Wrong mode!')
        exit(1)
models/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
3
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
7
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
8
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9
+ # ------------------------------------------------------------------------
10
+ # Modified from DETR (https://github.com/facebookresearch/detr)
11
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12
+ # ------------------------------------------------------------------------
13
+
14
+ from .sat_model import build_sat_model
15
+
16
+
models/criterion.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
2
+ import os
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from utils import box_ops
7
+ from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
8
+ accuracy, get_world_size, interpolate,
9
+ is_dist_avail_and_initialized, inverse_sigmoid)
10
+
11
def focal_loss(inputs, targets, valid_mask = None, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Unlike the usual logits-based variant, this version expects ``inputs`` to be
    probabilities already (binary_cross_entropy, not _with_logits).

    Args:
        inputs: A float tensor of arbitrary shape with predicted probabilities.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        valid_mask: accepted for API compatibility but currently ignored — the
                 masking step is commented out in the original implementation.
        alpha: (optional) Weighting factor in (0,1) to balance positive vs
                 negative examples; a negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) used to
                 balance easy vs hard examples.
    Returns:
        Scalar tensor: mean focal loss over all elements.
    """
    # Element-wise BCE on probabilities (inputs are assumed pre-sigmoided).
    bce = F.binary_cross_entropy(inputs, targets, reduction="none")
    # p_t: probability the model assigned to the true class of each element.
    p_t = inputs * targets + (1 - inputs) * (1 - targets)
    # Down-weight easy examples via the (1 - p_t)^gamma modulating factor.
    focal = bce * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # Class-balance weighting: alpha for positives, (1 - alpha) for negatives.
        focal = focal * (alpha * targets + (1 - alpha) * (1 - targets))

    # NOTE: valid_mask is intentionally not applied (matches original behavior).
    return focal.mean()
42
+
43
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, matcher, weight_dict, losses = ['confs','boxes', 'poses','betas', 'j3ds','j2ds', 'depths', 'kid_offsets'],
                 focal_alpha=0.25, focal_gamma = 2.0, j2ds_norm_scale = 518):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
                NOTE: mutable default argument — callers should pass their own list.
            focal_alpha: alpha in Focal Loss
            focal_gamma: gamma in Focal Loss (stored; the module-level focal_loss uses its own defaults)
            j2ds_norm_scale: divisor used to normalize 2D joints before the L1 loss
        """
        super().__init__()
        self.matcher = matcher
        self.losses = losses
        # The GIoU term reuses the box-loss weight unless one is given explicitly.
        if 'boxes' in losses and 'giou' not in weight_dict:
            weight_dict.update({'giou': weight_dict['boxes']})
        self.weight_dict = weight_dict


        # Per-coefficient weights for the 10 SMPL betas: earlier (higher-variance)
        # shape components are weighted more heavily.
        self.betas_weight = torch.tensor([2.56, 1.28, 0.64, 0.64, 0.32, 0.32, 0.32, 0.32, 0.32, 0.32]).unsqueeze(0).float()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.j2ds_norm_scale = j2ds_norm_scale
        # Set from the model outputs on each forward() call.
        self.device = None


    def loss_boxes(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        assert loss == 'boxes'
        idx = self._get_src_permutation_idx(indices)
        # Keep only matched pairs whose target dict actually carries this key
        # (some datasets do not annotate every quantity).
        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            return {loss: torch.tensor(0.).to(self.device)}

        src = outputs['pred_'+loss][idx][valid_idx]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
        assert src.shape == target.shape

        src_boxes = src
        target_boxes = target

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['boxes'] = loss_bbox.sum() / num_instances

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['giou'] = loss_giou.sum() / num_instances

        # # calculate the x,y and h,w loss
        # with torch.no_grad():
        #     losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
        #     losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes


        return losses

    def loss_boxes_enc(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Box L1 + GIoU losses for the *encoder* proposals.

        Same math as loss_boxes, but the encoder's predictions are stored as a
        flat (sum-of-tokens, 4) tensor split per image by ``outputs['lens']``,
        and the matcher indices here index into those per-image token slices.
        """
        assert 'pred_boxes' in outputs
        assert loss == 'boxes_enc'
        loss = 'boxes'

        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            return {loss: torch.tensor(0.).to(self.device)}

        lens = outputs['lens']
        pred_boxes = outputs['pred_boxes']
        src = torch.cat([s[i] for s, (i, _) in zip(pred_boxes.split(lens), indices)], dim=0)[valid_idx]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
        assert src.shape == target.shape

        src_boxes = src
        target_boxes = target

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['boxes'] = loss_bbox.sum() / num_instances

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['giou'] = loss_giou.sum() / num_instances

        # # calculate the x,y and h,w loss
        # with torch.no_grad():
        #     losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
        #     losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes


        return losses

    # For computing ['boxes', 'poses', 'betas', 'j3ds', 'j2ds'] losses
    def loss_L1(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Generic masked L1 loss between matched predictions and targets.

        Special cases: 'j3ds' are root-aligned and truncated to the first 54
        SMPL joints; 'j2ds' are normalized by j2ds_norm_scale, truncated to 54
        joints and masked by the per-joint validity mask; 'betas' are weighted
        per shape coefficient.
        """
        idx = self._get_src_permutation_idx(indices)
        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            return {loss: torch.tensor(0.).to(self.device)}

        src = outputs['pred_'+loss][idx][valid_idx]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)
        assert src.shape == target.shape

        losses = {}
        loss_mask = None

        if loss == 'j3ds':
            # Root aligned
            src = src - src[...,[0],:].clone()
            target = target - target[...,[0],:].clone()
            # Use 54 smpl joints
            src = src[:,:54,:]
            target = target[:,:54,:]
        elif loss == 'j2ds':
            src = src / self.j2ds_norm_scale
            target = target / self.j2ds_norm_scale
            # Need to exclude invalid kpts in 2d datasets
            loss_mask = torch.cat([t['j2ds_mask'][i] for t, (_, i) in zip(targets, indices) if 'j2ds' in t], dim=0)
            # Use 54 smpl joints
            src = src[:,:54,:]
            target = target[:,:54,:]
            loss_mask = loss_mask[:,:54,:]

        valid_loss = torch.abs(src-target)

        # if loss == 'j2ds':
        #     print(src.shape)
        #     print(target.shape)
        #     print(num_instances)
        #     exit(0)

        if loss_mask is not None:
            valid_loss = valid_loss * loss_mask
        if loss == 'betas':
            valid_loss = valid_loss*self.betas_weight.to(src.device)

        # Per-instance mean over all elements, then normalize by the (global)
        # instance count.
        losses[loss] = valid_loss.flatten(1).mean(-1).sum()/num_instances



        return losses

    def loss_scale_map(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Losses for the encoder's per-token (confidence, scale) map.

        Returns 'map_confs' (focal loss on the confidence channel) and
        'map_scales' (L1 on the scale channel at positive tokens only).
        """
        assert loss == 'scale_map'

        pred_map = outputs['enc_outputs']['scale_map']
        tgt_map = torch.cat([t['scale_map'] for t in targets], dim=0)
        assert pred_map.shape == tgt_map.shape

        # Channel 0 = binary person label, channel 1 = scale value.
        labels = tgt_map[:,0]
        pred_scales = pred_map[:,1]
        tgt_scales = tgt_map[:, 1]

        # Tokens are supervised as negatives only for images that are fully
        # annotated ('detect_all_people'); positives are always supervised.
        # NOTE(review): focal_loss currently ignores its valid_mask argument,
        # so this mask has no effect — confirm whether that is intended.
        detection_valid_mask = labels.bool()
        cur = 0
        lens = [len(t['scale_map']) for t in targets]
        for i, tgt in enumerate(targets):
            if tgt['detect_all_people']:
                detection_valid_mask[cur:cur+lens[i]] = True
            cur += lens[i]


        losses = {}
        losses['map_confs'] = focal_loss(pred_map[:,0], labels, valid_mask=detection_valid_mask)/1.
        losses['map_scales'] = torch.abs((pred_scales - tgt_scales)[torch.where(labels)[0]]).sum()/num_instances


        return losses

    def loss_confs(self, loss, outputs, targets, indices, num_instances, is_dn=False, **kwargs):
        """Focal loss on query confidences; matched queries are positives.

        For denoising (is_dn) queries the validity mask is skipped. Note that
        get_valid_instances pins num_instances to 1.0 for 'confs', so the
        division is a no-op in the non-dn case.
        """
        assert loss == 'confs'
        idx = self._get_src_permutation_idx(indices)
        pred_confs = outputs['pred_'+loss]

        with torch.no_grad():
            labels = torch.zeros_like(pred_confs)
            labels[idx] = 1
            # NOTE(review): focal_loss ignores valid_mask (see above), so this
            # mask is computed but has no effect on the loss value.
            detection_valid_mask = torch.zeros_like(pred_confs,dtype=bool)
            detection_valid_mask[idx] = True
            valid_batch_idx = torch.where(torch.tensor([t['detect_all_people'] for t in targets]))[0]
            detection_valid_mask[valid_batch_idx] = True


        losses = {}
        if is_dn:
            losses[loss] = focal_loss(pred_confs, labels) / num_instances
        else:
            losses[loss] = focal_loss(pred_confs, labels, valid_mask = detection_valid_mask) / num_instances

        return losses

    def loss_confs_enc(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Focal loss on encoder token confidences (flat per-token layout).

        num_instances is accepted but unused: the focal loss already averages
        over elements, and this variant is applied unnormalized.
        """
        assert loss == 'confs_enc'
        loss = 'confs'

        lens = outputs['lens']
        pred_confs = outputs['pred_confs']
        detection_valid_mask = torch.zeros_like(pred_confs,dtype=bool)
        labels = torch.zeros_like(pred_confs)

        # Convert per-image matcher indices into flat indices into the
        # concatenated token tensor.
        cur = 0
        idx = []
        for i, (src, tgt) in enumerate(indices):
            idx += (src + cur).tolist()
            if targets[i]['detect_all_people']:
                detection_valid_mask[cur:cur+lens[i]] = True
            cur += lens[i]
        detection_valid_mask[idx] = True
        labels[idx] = 1

        pred_confs = pred_confs.unsqueeze(0)
        labels = labels.unsqueeze(0)
        detection_valid_mask = detection_valid_mask.unsqueeze(0)

        losses = {}
        # losses[loss] = focal_loss(pred_confs, labels, valid_mask = detection_valid_mask)
        losses[loss] = focal_loss(pred_confs, labels)
        return losses

    def loss_L2(self, loss, outputs, targets, indices, num_instances, **kwargs):
        # NOTE: unimplemented placeholder; not referenced by get_loss's dispatch table.
        pass

    def loss_absolute_depths(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Inverse-depth L1 loss.

        Predictions store (d, d/f) in the last dim; the d/f channel is scaled
        by the ground-truth focal length to obtain an absolute depth, which is
        compared to the target depth in inverse-depth space.
        """
        assert loss == 'depths'
        losses = {}
        idx = self._get_src_permutation_idx(indices)
        valid_idx = torch.where(torch.cat([torch.ones(len(i), dtype=bool, device = self.device)*(loss in t) for t, (_, i) in zip(targets, indices)]))[0]

        if len(valid_idx) == 0:
            return {loss: torch.tensor(0.).to(self.device)}

        src = outputs['pred_'+loss][idx][valid_idx][...,[1]] # [d d/f]
        target = torch.cat([t[loss][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)[...,[0]]
        target_focals = torch.cat([t['focals'][i] for t, (_, i) in zip(targets, indices) if loss in t], dim=0)

        # print(src.shape, target.shape, target_focals.shape)

        src = target_focals * src

        assert src.shape == target.shape

        # Compare inverse depths; epsilon guards against division by zero.
        valid_loss = torch.abs(1./(src + 1e-8) - 1./(target + 1e-8))
        losses[loss] = valid_loss.flatten(1).mean(-1).sum()/num_instances
        return losses


    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_instances, **kwargs):
        """Dispatch a loss name to its implementation."""
        loss_map = {
            'confs': self.loss_confs,
            'boxes': self.loss_boxes,
            'confs_enc': self.loss_confs_enc,
            'boxes_enc': self.loss_boxes_enc,
            'poses': self.loss_L1,
            'betas': self.loss_L1,
            'j3ds': self.loss_L1,
            'j2ds': self.loss_L1,
            'depths': self.loss_absolute_depths,
            'scale_map': self.loss_scale_map,
        }
        # assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](loss, outputs, targets, indices, num_instances, **kwargs)


    def get_valid_instances(self, targets):
        # Compute the average number of target boxes accross all nodes, for normalization purposes
        # Losses: 'confs','centers','anchors', 'poses', 'betas', 'j3ds', 'j2ds', 'depths', 'ages', 'heatmap'
        # For 'scale_map' the count is the number of positive tokens; for the
        # others it is the number of annotated people carrying that key.
        num_valid_instances = {}
        for loss in self.losses:
            num_instances = 0
            if loss != 'scale_map':
                for t in targets:
                    num_instances += t['pnum'] if loss in t else 0
                num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=self.device)
            else:
                for t in targets:
                    num_instances += t['scale_map'][...,0].sum().item()
                num_instances = torch.as_tensor([num_instances], dtype=torch.float, device=self.device)
            if is_dist_avail_and_initialized():
                torch.distributed.all_reduce(num_instances)
            num_instances = torch.clamp(num_instances / get_world_size(), min=1).item()
            num_valid_instances[loss] = num_instances
        # Confidence loss is normalized to 1 (focal loss already averages).
        num_valid_instances['confs'] = 1.
        return num_valid_instances

    def prep_for_dn(self, dn_meta):
        """Unpack denoising metadata: known outputs, per-group pad size, group count."""
        output_known = dn_meta['output_known']
        num_dn_groups, pad_size = dn_meta['num_dn_group'], dn_meta['pad_size']
        assert pad_size % num_dn_groups == 0
        single_pad = pad_size//num_dn_groups

        return output_known, single_pad, num_dn_groups

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        # remove invalid information in targets
        # (mutates the target dicts in place: 3D annotations from 2D-only
        # samples are dropped so the per-loss valid_idx filters skip them)
        for t in targets:
            if not t['3d_valid']:
                for key in ['betas', 'kid_offsets', 'poses', 'j3ds', 'depths', 'focals']:
                    if key in t:
                        del t[key]

        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs' and k != 'sat'}
        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        self.device = outputs['pred_poses'].device
        num_valid_instances = self.get_valid_instances(targets)

        # Compute all the requested losses
        losses = {}

        # prepare for dn loss
        if 'dn_meta' in outputs:
            dn_meta = outputs['dn_meta']
            output_known, single_pad, scalar = self.prep_for_dn(dn_meta)

            # Denoising queries are matched to GT by construction: group g's
            # query g*single_pad + j corresponds to target j (positives in the
            # first half of each group, negatives offset by single_pad // 2).
            dn_pos_idx = []
            dn_neg_idx = []
            for i in range(len(targets)):
                assert len(targets[i]['boxes']) > 0
                # t = torch.range(0, len(targets[i]['labels']) - 1).long().to(self.device)
                t = torch.arange(0, len(targets[i]['labels'])).long().to(self.device)
                t = t.unsqueeze(0).repeat(scalar, 1)
                tgt_idx = t.flatten()
                output_idx = (torch.tensor(range(scalar)) * single_pad).long().to(self.device).unsqueeze(1) + t
                output_idx = output_idx.flatten()

                dn_pos_idx.append((output_idx, tgt_idx))
                dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))

            l_dict = {}
            for loss in self.losses:
                if loss == 'scale_map':
                    continue
                l_dict.update(self.get_loss(loss, output_known, targets, dn_pos_idx, num_valid_instances[loss]*scalar, is_dn=True))

            l_dict = {k + f'_dn': v for k, v in l_dict.items()}
            losses.update(l_dict)

        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_valid_instances[loss]))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'scale_map':
                        continue
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_valid_instances[loss])
                    l_dict = {f'{k}.{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

                if 'dn_meta' in outputs:
                    # NOTE(review): this guard reads the stale `loss` variable
                    # left over from the loop above (always the last entry of
                    # self.losses here), and the scale_map skip is missing from
                    # the inner loop below — confirm whether the guard was
                    # meant to live inside that inner loop.
                    if loss == 'scale_map':
                        continue
                    aux_outputs_known = output_known['aux_outputs'][i]
                    l_dict={}
                    for loss in self.losses:
                        l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_valid_instances[loss]*scalar, is_dn=True))
                    l_dict = {k + f'_dn.{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # if 'scale_map' in outputs:
        #     enc_outputs = outputs['enc_outputs']
        #     indices = self.matcher.forward_enc(enc_outputs, targets)
        #     for loss in ['confs_enc', 'boxes_enc']:
        #         l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_valid_instances[loss.replace('_enc','')])
        #         l_dict = {k + f'_enc': v for k, v in l_dict.items()}
        #         losses.update(l_dict)

        return losses
+ return losses
models/decoder.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
2
+
3
+ import math
4
+ import copy
5
+ import os
6
+ from typing import Optional, List
7
+ from utils.misc import inverse_sigmoid
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.nn.functional import scaled_dot_product_attention
12
+ from torch import nn, Tensor
13
+ from torch.nn.init import constant_
14
+
15
+ from .position_encoding import position_encoding_xy
16
+
17
+ from xformers.ops import memory_efficient_attention, fmha
18
+
19
+
20
+
21
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN).

    num_layers linear layers of sizes input_dim -> hidden_dim -> ... -> output_dim,
    with ReLU applied after every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Chain of layer widths: input, (num_layers - 1) hidden, output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)
        )

    def forward(self, x):
        last = self.num_layers - 1
        for depth, layer in enumerate(self.layers):
            x = layer(x)
            if depth < last:
                x = F.relu(x)
        return x
34
+
35
class TransformerDecoder(nn.Module):
    """DAB-DETR style decoder wrapper.

    Builds a stack of XformerDecoderLayer modules (defined later in this file)
    wrapped in an XformerDecoder, and converts an optional boolean
    self-attention mask into an additive float bias before forwarding.
    """
    def __init__(self, d_model=512, nhead=8, num_queries=300,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu",
                 return_intermediate_dec=False, query_dim=4,
                 keep_query_pos=False, query_scale_type='cond_elewise',
                 modulate_hw_attn=True,
                 bbox_embed_diff_each_layer=True,
                 ):

        super().__init__()


        decoder_layer = XformerDecoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, keep_query_pos=keep_query_pos)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = XformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                      return_intermediate=return_intermediate_dec,
                                      d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
                                      modulate_hw_attn=modulate_hw_attn,
                                      bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)

        self._reset_parameters()
        assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries

    def _reset_parameters(self):
        # Xavier-initialize every parameter with more than one dimension
        # (weight matrices; biases and norm scales keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def mask2bias(self, mask, batch_size):
        """Turn a boolean (L, S) mask into an additive (B, nhead, L, S) bias.

        True entries become -inf (blocked attention); False entries become 0.
        Returns None if no mask is given.
        """
        if mask is None:
            return None

        assert mask.dtype == torch.bool
        assert mask.ndim == 2
        L, S = mask.shape[0], mask.shape[1]
        # Allocate the key dim padded up to a multiple of 8, then slice back to
        # S — presumably to give the bias the memory alignment that xformers'
        # memory_efficient_attention expects on its last dim; TODO confirm.
        pad_size = (S + 7) // 8 * 8
        bias = torch.zeros((batch_size, self.nhead, L, pad_size), device = mask.device)[:,:,:,:S]
        bias.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        return bias


    def forward(self, memory, memory_lens, tgt, tgt_lens, refpoint_embed, pos_embed, self_attn_mask):
        """Run the inner decoder; returns (hidden states, reference points)."""
        self_attn_bias = self.mask2bias(self_attn_mask, batch_size=len(memory_lens))
        hs, references = self.decoder(memory=memory, memory_lens=memory_lens,
                                      tgt=tgt, tgt_lens=tgt_lens,
                                      pos=pos_embed, refpoints_unsigmoid=refpoint_embed,
                                      self_attn_bias = self_attn_bias)
        return hs, references
89
+ return hs, references
90
+
91
+
92
class XformerDecoder(nn.Module):
    """DAB-DETR style decoder operating on xformers "packed" sequences.

    Queries and memory are flattened over the batch (shape [sum(lens), C]);
    per-image boundaries are carried by ``tgt_lens`` / ``memory_lens`` and
    enforced inside each layer via block-diagonal attention masks.
    Reference boxes (anchors) are refined layer by layer when an external
    ``bbox_embed`` head is attached after construction.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=True,
                 d_model=512, query_dim=4, keep_query_pos=False, query_scale_type='cond_elewise',
                 modulate_hw_attn=False,
                 bbox_embed_diff_each_layer=False,
                 ):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        # only the intermediate-returning path is implemented in forward()
        assert return_intermediate
        self.query_dim = query_dim

        assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
        self.query_scale_type = query_scale_type
        # query_scale modulates the sine positional query per layer:
        #   cond_elewise: per-channel scale predicted from decoder output
        #   cond_scalar:  single scalar scale predicted from decoder output
        #   fix_elewise:  one learned scale vector per layer
        if query_scale_type == 'cond_elewise':
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        elif query_scale_type == 'cond_scalar':
            self.query_scale = MLP(d_model, d_model, 1, 2)
        elif query_scale_type == 'fix_elewise':
            self.query_scale = nn.Embedding(num_layers, d_model)
        else:
            raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))

        # maps the sine embedding of the 4D anchor (query_dim//2 * d_model wide)
        # to the content-space positional query
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)

        # set externally (e.g. by the detector) to enable iterative box refinement
        self.bbox_embed = None
        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer

        if modulate_hw_attn:
            # predicts a (w, h) reference used to rescale the positional attention
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)

        # DAB-DETR: only the first layer adds the query positional embedding in
        # cross-attention; disable the projection in the remaining layers
        if not keep_query_pos:
            for layer_id in range(num_layers - 1):
                self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, memory, memory_lens, tgt, tgt_lens,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # [L_tgt, 4] anchor logits
                self_attn_bias = None):
        """Run all decoder layers with iterative anchor refinement.

        Args:
            memory, memory_lens: packed encoder features and per-image lengths.
            tgt, tgt_lens: packed query embeddings and per-image lengths
                (assumes every image has the same number of queries — num_queries
                is taken from tgt_lens[0]).
            pos: positional embedding for memory.
            refpoints_unsigmoid: anchor boxes before sigmoid, shape [L_tgt, 4].
            self_attn_bias: optional attention bias (used for dn training).

        Returns:
            [stacked intermediate outputs (num_layers, B, num_queries, d_model),
             stacked reference points] when return_intermediate is set.
        """
        B, num_queries = len(tgt_lens), tgt_lens[0]
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points.view(B, num_queries, self.query_dim)]

        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[:, :self.query_dim]  # [L_tgt, 4]
            # sine embedding of the current anchor: (cx, cy) and (w, h) halves
            xy_embed = position_encoding_xy(obj_center[:,0], obj_center[:,1], self.d_model)
            wh_embed = position_encoding_xy(obj_center[:,2], obj_center[:,3], self.d_model)
            query_sine_embed = torch.cat([xy_embed,wh_embed],dim=1)  # [L_tgt, 2*d_model]
            query_pos = self.ref_point_head(query_sine_embed)

            # For the first decoder layer, we do not apply transformation over p_s
            if self.query_scale_type != 'fix_elewise':
                if layer_id == 0:
                    pos_transformation = 1
                else:
                    pos_transformation = self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]

            # apply transformation; only the (cx, cy) half of the embedding is kept
            query_sine_embed = query_sine_embed[:,:self.d_model] * pos_transformation

            # modulated HW attention: rescale positional attention by predicted
            # w/h relative to the current anchor size (DAB-DETR eq. 5)
            if self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()
                query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)

            output = layer(memory=memory, memory_lens=memory_lens,
                           tgt=output, tgt_lens=tgt_lens,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0),
                           self_attn_bias = self_attn_bias)

            # iterative anchor update (only when a bbox head has been attached)
            if self.bbox_embed is not None:
                if self.bbox_embed_diff_each_layer:
                    tmp = self.bbox_embed[layer_id](self.norm(output))
                else:
                    tmp = self.bbox_embed(self.norm(output))
                # refine in logit space, then squash back to [0, 1]
                tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
                new_reference_points = tmp[..., :self.query_dim].sigmoid()
                if layer_id != self.num_layers - 1:
                    ref_points.append(new_reference_points.view(B, num_queries, self.query_dim))
                # detach so gradients only flow through the current layer's update
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output).view(B, num_queries, self.d_model))

        if self.return_intermediate:
            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate),
                    torch.stack(ref_points),
                ]
            else:
                return [
                    torch.stack(intermediate),
                    reference_points.unsqueeze(0)
                ]

        # unreachable given the assert in __init__, kept for safety
        return output.unsqueeze(0)
214
+
215
class XformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer (Conditional-DETR style) using xformers attention.

    Self-attention and cross-attention each use separate content/position
    projections; cross-attention concatenates content and sine-position halves
    per head (doubling the head dim). Sequences are packed across the batch and
    separated with BlockDiagonalMask attention biases.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.0,
                 activation="relu", keep_query_pos=False):
        super().__init__()
        # Decoder Self-Attention: decomposed q/k/v projections for content and
        # positional parts (Conditional DETR)
        self.sa_qcontent_proj = nn.Linear(d_model, d_model)
        self.sa_qpos_proj = nn.Linear(d_model, d_model)
        self.sa_kcontent_proj = nn.Linear(d_model, d_model)
        self.sa_kpos_proj = nn.Linear(d_model, d_model)
        self.sa_v_proj = nn.Linear(d_model, d_model)
        self.sa_out_proj = nn.Linear(d_model, d_model)
        constant_(self.sa_out_proj.bias, 0.)

        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Decoder Cross-Attention
        self.ca_qcontent_proj = nn.Linear(d_model, d_model)
        # may be set to None by the parent decoder for layers > 0
        self.ca_qpos_proj = nn.Linear(d_model, d_model)
        self.ca_kcontent_proj = nn.Linear(d_model, d_model)
        self.ca_kpos_proj = nn.Linear(d_model, d_model)
        self.ca_v_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
        self.ca_out_proj = nn.Linear(d_model, d_model)
        constant_(self.ca_out_proj.bias, 0.)

        self.d_model = d_model
        self.nhead = nhead
        assert self.d_model%self.nhead == 0

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.keep_query_pos = keep_query_pos

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add a positional embedding to a tensor (no-op when pos is None)."""
        return tensor if pos is None else tensor + pos

    def forward(self, memory, memory_lens, pos,
                tgt, tgt_lens, query_pos, query_sine_embed,
                is_first=False,
                self_attn_bias=None):
        """One decoder layer: self-attn over queries, cross-attn to memory, FFN.

        self_attn_bias is only used for dn training; 'True' indicates that the
        element should take part in attention.
        Assumes every image contributes tgt_lens[0] queries (uniform lengths).
        """
        B, num_queries = len(tgt_lens), tgt_lens[0]
        L_mem, C_mem = memory.shape
        L_tgt, C_tgt = tgt.shape
        assert C_mem == C_tgt

        # ========== Begin of Self-Attention =============
        # pre-norm: keep the un-normalized input for the residual connection
        tgt_b4n = tgt
        tgt = self.norm1(tgt)

        q_content = self.sa_qcontent_proj(tgt)
        q_pos = self.sa_qpos_proj(query_pos)
        k_content = self.sa_kcontent_proj(tgt)
        k_pos = self.sa_kpos_proj(query_pos)
        v = self.sa_v_proj(tgt)

        q = q_content + q_pos
        k = k_content + k_pos

        # reshape to (B, num_queries, heads, head_dim) for xformers
        q = q.view(B, num_queries, self.nhead, self.d_model // self.nhead)
        k = k.view(B, num_queries, self.nhead, self.d_model // self.nhead)
        v = v.view(B, num_queries, self.nhead, self.d_model // self.nhead)

        tgt2 = memory_efficient_attention(q, k, v, attn_bias=self_attn_bias)
        tgt2 = self.sa_out_proj(tgt2.view(L_tgt, self.d_model))

        tgt = tgt_b4n + self.dropout1(tgt2)
        # ========== End of Self-Attention =============

        # ========== Begin of Cross-Attention =============
        tgt_b4n = tgt
        tgt = self.norm2(tgt)

        q_content = self.ca_qcontent_proj(tgt)
        k_content = self.ca_kcontent_proj(memory)
        v = self.ca_v_proj(memory)

        k_pos = self.ca_kpos_proj(pos)

        # For the first decoder layer, we concatenate the positional embedding
        # predicted from the object query (the positional embedding) into the
        # original query (key) in DETR.
        if is_first or self.keep_query_pos:
            q_pos = self.ca_qpos_proj(query_pos)
            q = q_content + q_pos
            k = k_content + k_pos
        else:
            q = q_content
            k = k_content

        # per-head concatenation of content and sine-position halves
        # (head_dim is doubled for q and k; v keeps head_dim)
        q = q.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
        query_sine_embed = query_sine_embed.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
        q = torch.cat([q, query_sine_embed], dim=3)

        k = k.view(1, L_mem, self.nhead, self.d_model//self.nhead)
        k_pos = k_pos.view(1, L_mem, self.nhead, self.d_model//self.nhead)
        k = torch.cat([k, k_pos], dim=3)

        v = v.view(1, L_mem, self.nhead, self.d_model//self.nhead)

        # block-diagonal mask keeps each image's queries attending only to its
        # own memory tokens within the packed sequence
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(q_seqlen = tgt_lens, kv_seqlen = memory_lens)
        tgt2 = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        tgt2 = self.ca_out_proj(tgt2.view(L_tgt, self.d_model))

        tgt = tgt_b4n + self.dropout2(tgt2)
        # ========== End of Cross-Attention =============

        # FFN (pre-norm)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm3(tgt)))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt
342
+
343
+
344
+ def _get_activation_fn(activation):
345
+ """Return an activation function given a string"""
346
+ if activation == "relu":
347
+ return F.relu
348
+ if activation == "gelu":
349
+ return F.gelu
350
+ if activation == "glu":
351
+ return F.glu
352
+ if activation == "prelu":
353
+ return nn.PReLU()
354
+ if activation == "selu":
355
+ return F.selu
356
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
357
+
358
+ def _get_clones(module, N):
359
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
360
+
361
+
362
def build_decoder(args):
    """Build the transformer decoder from the run configuration.

    Reads hidden_dim, dropout, nheads, num_queries, dim_feedforward,
    dec_layers and transformer_activation from *args*; always uses 4-D
    (cx, cy, w, h) reference points and returns intermediate layer outputs.
    """
    return TransformerDecoder(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        num_queries=args.num_queries,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,
        query_dim=4,
        activation=args.transformer_activation
    )
374
+
375
+
376
def torch_attention(query, key, value, attn_bias = None):
    """Plain-PyTorch scaled dot-product attention (xformers fallback).

    Args:
        query, key, value: tensors laid out as (B, L, H, D), i.e. heads in
            dim 2 as expected by xformers' memory_efficient_attention.
        attn_bias: optional additive bias broadcastable to (B, H, Lq, Lk).

    Returns:
        Attention output with the same (B, L, H, D) layout as the inputs.
    """
    head_dim = query.shape[-1]
    # move heads before the sequence dim: (B, H, L, D)
    q = query.transpose(1, 2) * (head_dim ** -0.5)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)

    scores = torch.matmul(q, k.transpose(-2, -1))
    if attn_bias is not None:
        scores = scores + attn_bias
    weights = scores.softmax(dim=-1)
    out = torch.matmul(weights, v)

    # restore the (B, L, H, D) layout
    return out.transpose(1, 2)
models/dn_components.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DINO (https://github.com/IDEA-Research/DINO)
2
+
3
+
4
+ import torch
5
+ from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
6
+ accuracy, get_world_size, interpolate,
7
+ is_dist_avail_and_initialized, inverse_sigmoid)
8
+ from utils import box_ops
9
+ import torch.nn.functional as F
10
+
11
+
12
def prepare_for_cdn(targets, dn_cfg, num_queries, hidden_dim, dn_enc):
    """
    Prepare contrastive-denoising (CDN) queries from ground-truth targets.

    A major difference of DINO from DN-DETR is that the author process pattern
    embedding in its detector forward function and use learnable tgt embedding,
    so we change this function a little bit.

    Each GT box is replicated into ``2 * dn_number`` noised copies: the first
    half are "positive" (small noise, should be reconstructed) and the second
    half are "negative" (larger noise, should be rejected). The noised queries
    are padded to a fixed size per image and prepended to the matching queries;
    an attention mask keeps dn groups from seeing each other or being seen by
    the matching queries.

    :param targets: per-image GT dicts with 'labels', 'boxes' (and, depending
        on dn_cfg['tgt_embed_type'], 'poses'/'betas').
    :param dn_cfg: dn settings (dn_number, box_noise_scale, tgt_noise_scale,
        tgt_embed_type, dn_labelbook_size).
    :param num_queries: number of matching queries.
    :param hidden_dim: transformer hidden dim.
    :param dn_enc: encoder that maps noised labels/params to query embeddings.
    :return: (input_query_tgt, input_query_bbox, attn_mask, dn_meta)
    """
    device = targets[0]['boxes'].device

    dn_number = dn_cfg['dn_number']
    box_noise_scale = dn_cfg['box_noise_scale']
    tgt_noise_scale = dn_cfg['tgt_noise_scale']
    known = [(torch.ones_like(t['labels'])) for t in targets]
    batch_size = len(known)
    known_num = [sum(k) for k in known]

    # dn_number >= 100 is interpreted as a total query budget to split over
    # the 2*max(known_num) noised copies; clamp to at least one group
    if int(max(known_num)) == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            dn_number = dn_number // (int(max(known_num) * 2))
        elif dn_number < 1:
            dn_number = 1
    if dn_number == 0:
        dn_number = 1

    unmask_bbox = torch.cat(known)

    boxes = torch.cat([t['boxes'] for t in targets])
    assert boxes.ndim == 2
    # per-GT index of the image it belongs to
    batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
    known_indice = torch.nonzero(unmask_bbox)
    known_indice = known_indice.view(-1)
    # replicate across the 2*dn_number noised copies
    known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
    known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)

    # per-image padded dn slot count (positives + negatives for each group)
    single_pad = int(max(known_num))
    pad_size = int(single_pad * 2 * dn_number)
    # flat indices of positive copies; negatives follow len(boxes) later
    positive_idx = torch.tensor(range(len(boxes))).long().to(device=device).unsqueeze(0).repeat(dn_number, 1)
    positive_idx += (torch.tensor(range(dn_number)) * len(boxes) * 2).long().to(device=device).unsqueeze(1)
    positive_idx = positive_idx.flatten()
    negative_idx = positive_idx + len(boxes)

    # box queries: jitter corners within the box extent (positives) or by an
    # extra unit (negatives), then convert back to cxcywh logits
    known_bboxs = boxes.repeat(2 * dn_number, 1)
    known_bbox_expand = known_bboxs.clone()
    if box_noise_scale > 0:
        # cxcywh -> xyxy
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
        known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

        # max per-coordinate displacement = half width / half height
        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = known_bboxs[:, 2:] / 2
        diff[:, 2:] = known_bboxs[:, 2:] / 2

        rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        rand_part = torch.rand_like(known_bboxs)
        # negatives get noise magnitude in [1, 2): pushed clearly off the GT
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + torch.mul(rand_part,
                                              diff).to(device=device) * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        # xyxy -> cxcywh
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    # tgt queries: noised class labels or noised SMPL params, embedded by dn_enc
    if dn_cfg['tgt_embed_type'] == 'labels':
        labels = torch.cat([t['labels'] for t in targets])
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_labels_expaned = known_labels.clone()
        if tgt_noise_scale > 0:
            # flip a random subset of labels to random classes
            p = torch.rand_like(known_labels_expaned.float())
            chosen_indice = torch.nonzero(p < tgt_noise_scale).view(-1)
            new_label = torch.randint_like(chosen_indice, 0, dn_cfg['dn_labelbook_size'])  # randomly put a new one here
            known_labels_expaned.scatter_(0, chosen_indice, new_label)
        m = known_labels_expaned.long().to(device=device)
        input_tgt_embed = dn_enc(m)
    elif dn_cfg['tgt_embed_type'] == 'params':
        poses = torch.cat([t['poses'] for t in targets])
        betas = torch.cat([t['betas'] for t in targets])
        params = torch.cat([poses, betas], dim=-1)
        assert params.ndim == 2
        known_params = params.repeat(2 * dn_number, 1)
        known_params_expaned = known_params.clone()
        if tgt_noise_scale > 0:
            # additive uniform noise; negatives again get magnitude in [1, 2)
            rand_sign = torch.randint_like(known_params, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
            rand_part = torch.rand_like(known_params)
            rand_part[negative_idx] += 1.0
            rand_part *= rand_sign
            known_params_expaned = known_params_expaned + rand_part * tgt_noise_scale
        m = known_params_expaned.to(device=device)
        input_tgt_embed = dn_enc(m)

    # scatter the noised queries into fixed-size padded tensors per image
    padding_tgt = torch.zeros((pad_size, hidden_dim), device=device)
    padding_bbox = torch.zeros((pad_size, 4), device=device)

    input_query_tgt = padding_tgt.repeat(batch_size, 1, 1)
    input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

    map_known_indice = torch.tensor([]).to(device=device)
    if len(known_num):
        map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
        map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
    if len(known_bid):
        input_query_tgt[(known_bid.long(), map_known_indice)] = input_tgt_embed
        input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

    # prepare attn_mask ('True' means masked out here)
    tgt_size = pad_size + num_queries
    attn_mask = torch.zeros((tgt_size, tgt_size), dtype=bool, device=device)
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct cannot see each other (group i only attends within group i)
    for i in range(dn_number):
        if i == 0:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
        if i == dn_number - 1:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
        else:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }

    return input_query_tgt, input_query_bbox, attn_mask, dn_meta
149
+
150
+
151
def dn_post_process(pred_poses, pred_betas,
                    pred_boxes, pred_confs,
                    pred_j3ds, pred_j2ds, pred_depths,
                    pred_verts, pred_transl,
                    dn_meta, aux_loss, _set_aux_loss):
    """
    post process of dn after output from the transformer:
    split every prediction tensor into the denoising part (first ``pad_size``
    queries) and the matching part (the rest), and put the dn part in dn_meta.

    The first seven tensors are sliced on dim 2 (layer dim first); verts and
    transl are sliced on dim 1. ``dn_meta`` is mutated in place: its
    'output_known' entry receives the last-layer dn outputs (plus per-layer
    aux outputs when ``aux_loss`` is set, built via ``_set_aux_loss``).

    Returns the matching-query predictions, in the same order as the inputs
    (minus verts/transl being last).
    """
    assert dn_meta['pad_size'] > 0
    pad_size = dn_meta['pad_size']

    # dn queries were prepended, so the first pad_size slots are the dn part
    known_poses, pred_poses = pred_poses[:,:,:pad_size], pred_poses[:,:,pad_size:]
    known_betas, pred_betas = pred_betas[:,:,:pad_size], pred_betas[:,:,pad_size:]
    known_boxes, pred_boxes = pred_boxes[:,:,:pad_size], pred_boxes[:,:,pad_size:]
    known_confs, pred_confs = pred_confs[:,:,:pad_size], pred_confs[:,:,pad_size:]
    known_j3ds, pred_j3ds = pred_j3ds[:,:,:pad_size], pred_j3ds[:,:,pad_size:]
    known_j2ds, pred_j2ds = pred_j2ds[:,:,:pad_size], pred_j2ds[:,:,pad_size:]
    known_depths, pred_depths = pred_depths[:,:,:pad_size], pred_depths[:,:,pad_size:]

    # verts/transl have no layer dim; slice on dim 1
    known_verts, pred_verts = pred_verts[:,:pad_size], pred_verts[:,pad_size:]
    known_transl, pred_transl = pred_transl[:,:pad_size], pred_transl[:,pad_size:]

    # last decoder layer's dn outputs, used to compute the dn losses
    out = {'pred_poses': known_poses[-1], 'pred_betas': known_betas[-1],
           'pred_boxes': known_boxes[-1], 'pred_confs': known_confs[-1],
           'pred_j3ds': known_j3ds[-1], 'pred_j2ds': known_j2ds[-1],
           'pred_depths': known_depths[-1]}

    if aux_loss:
        out['aux_outputs'] = _set_aux_loss(known_poses, known_betas,
                                           known_boxes, known_confs,
                                           known_j3ds, known_j2ds, known_depths)

    dn_meta['output_known'] = out

    return pred_poses, pred_betas,\
           pred_boxes, pred_confs,\
           pred_j3ds, pred_j2ds,\
           pred_depths, pred_verts,\
           pred_transl,
192
+
193
+
models/encoders/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DINOv2 (https://github.com/facebookresearch/dinov2)
2
+ from models.encoders.dinov2.models.vision_transformer import vit_base, vit_large
3
+ import torch
4
+ from configs.paths import dinov2_vitb14_path, dinov2_vitl14_path
5
+ import copy
6
+
7
def build_encoder(args):
    """Build the DINOv2 ViT backbone selected by ``args.encoder``.

    Supports 'vitb' and 'vitl'. When SAT with additional blocks is enabled,
    the first ``get_map_layer`` blocks are duplicated as 'additional_blocks'
    (both in the model and in the pretrained weights). Pretrained DINOv2
    weights are loaded only in train mode; at inference the caller is expected
    to load a full checkpoint instead.
    """
    num_additional_blocks = 0
    if args.sat_cfg['use_sat'] and args.sat_cfg['use_additional_blocks']:
        num_additional_blocks = args.sat_cfg['get_map_layer']

    weights = None
    if args.encoder == 'vitb':
        # hyper-parameters match the official DINOv2 ViT-B/14 release
        model = vit_base(img_size = 518,
                         patch_size = 14,
                         init_values = 1.0,
                         ffn_layer = "mlp",
                         block_chunks = 0,
                         num_register_tokens = 0,
                         interpolate_antialias = False,
                         interpolate_offset = 0.1,
                         num_additional_blocks = num_additional_blocks)
        if args.mode.lower() == 'train':
            weights = torch.load(dinov2_vitb14_path)
    elif args.encoder == 'vitl':
        model = vit_large(img_size = 518,
                          patch_size = 14,
                          init_values = 1.0,
                          ffn_layer = "mlp",
                          block_chunks = 0,
                          num_register_tokens = 0,
                          interpolate_antialias = False,
                          interpolate_offset = 0.1,
                          num_additional_blocks = num_additional_blocks)
        if args.mode.lower() == 'train':
            weights = torch.load(dinov2_vitl14_path)
    else:
        raise NotImplementedError

    if weights is not None:
        if args.sat_cfg['use_sat'] and args.sat_cfg['use_additional_blocks']:
            # mirror the first blocks' weights into 'additional_blocks.*' keys
            # so load_state_dict(strict=True) succeeds
            add_blocks_weights(weights, args.sat_cfg['get_map_layer'])
        print('Loading pretrained DINOv2...')
        model.load_state_dict(weights,strict=True)

    return model
47
+
48
def add_blocks_weights(weights, num_layers):
    """Duplicate the first *num_layers* 'blocks.*' entries of a state dict
    under an 'additional_blocks.*' prefix, mutating *weights* in place."""
    for key in list(weights.keys()):
        if not key.startswith('blocks'):
            continue
        block_index = int(key.split('.')[1])
        if block_index < num_layers:
            mirrored_key = key.replace('blocks', 'additional_blocks')
            weights[mirrored_key] = copy.deepcopy(weights[key])
models/encoders/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
models/encoders/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
class Attention(nn.Module):
    """Standard multi-head self-attention as used in ViT."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # fused projection producing q, k and v in one matmul
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, H, N, D)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # merge heads back: (B, H, N, D) -> (B, N, C)
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
+
71
+
72
class MemEffAttention(Attention):
    """Attention variant that uses xformers' memory_efficient_attention and
    supports an additive/block-diagonal attention bias for nested tensors."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        # fall back to the plain implementation when xformers is missing;
        # that path cannot honor an attn_bias, so reject it explicitly
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        # xformers layout: (B, N, 3, heads, head_dim) — no transpose needed
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
models/encoders/dinov2/layers/block.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+
40
+ warnings.warn("xFormers is not available (Block)")
41
+
42
+
43
class Block(nn.Module):
    """ViT transformer block: pre-norm attention + MLP, each with optional
    LayerScale and stochastic depth (drop path)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,           # LayerScale init; None/0 disables LayerScale
        drop_path: float = 0.0,     # stochastic depth rate
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1:
            # drop whole samples from the batch instead of per-sample masking
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            # eval / no stochastic depth: plain residual connections
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
115
+
116
+
117
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Sample-level stochastic depth: compute the residual branch only on a
    random subset of the batch and add it back, rescaled so the expected
    update matches computing the branch on every sample.

    Args:
        x: input of shape (batch, tokens, dim).
        residual_func: the residual branch (e.g. attention or MLP).
        sample_drop_ratio: fraction of the batch to skip.

    Returns:
        Tensor with the same shape as x.
    """
    batch, tokens, dim = x.shape

    # pick a random subset of samples to actually run the branch on
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:keep]

    residual = residual_func(x[brange]).flatten(1)

    # upweight kept samples to keep the expected residual magnitude unchanged
    scale = batch / keep
    flat = torch.index_add(x.flatten(1), 0, brange, residual.to(dtype=x.dtype), alpha=scale)
    return flat.view_as(x)
139
+
140
+
141
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the random batch indices kept for stochastic depth, plus the
    scale factor compensating for the dropped samples.

    Returns:
        (brange, residual_scale_factor): kept indices and batch/kept ratio.
    """
    batch, _tokens, _dim = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:keep]
    return brange, batch / keep
147
+
148
+
149
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a sub-batch residual back into x at rows *brange*.

    When *scaling_vector* (a LayerScale gamma) is given, the fused xformers
    ``scaled_index_add`` kernel is used and a tensor shaped like x is
    returned; otherwise a plain ``index_add`` on the flattened tensor is used
    and the result is returned flattened to 2-D (callers reshape it).
    """
    if scaling_vector is not None:
        # fused path: per-channel scaling + scaled index add in one kernel
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    flat_x = x.flatten(1)
    flat_res = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(flat_x, 0, brange, flat_res, alpha=residual_scale_factor)
159
+
160
+
161
# cache of BlockDiagonalMask objects keyed by the per-group (batch, seqlen)
# shapes, so masks are built once per distinct shape combination
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Each tensor in x_list is (batch_i, seqlen_i, dim); optionally only the
    rows in branges[i] are taken from each. All selected sequences are packed
    into one (1, total_tokens, dim) tensor, and a block-diagonal mask keeps
    attention within each original sequence.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # one seqlen entry per (kept) sample per group
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
        # remembered so attn_bias.split() can restore the per-group batches
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # fused xformers kernel: gather the selected rows and flatten-concat
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
186
+
187
+
188
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual add over a list of nested tensors: drop a
    random subset of samples per group, run residual_func once on the packed
    concatenation, then scatter the rescaled residuals back per group."""
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result per group
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
209
+
210
+
211
class NestedTensorBlock(Block):
    """Block that can also process a list of tensors with different shapes by
    packing them into one sequence with a block-diagonal attention mask."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        # the packed path needs attn_bias support in the attention module
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # stochastic-depth path: LayerScale is folded into the fused
            # scaled_index_add (scaling_vector), so it is NOT applied here

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # NOTE(review): condition checks self.ls1 but passes self.ls2.gamma —
                # harmless when both are the same type, but looks like a typo; confirm
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # pack all sequences, run once, then split back per input tensor
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # plain tensor -> standard Block forward; list -> nested (xformers) path
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
models/encoders/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
class DINOHead(nn.Module):
    """Projection head for DINO-style self-distillation.

    An MLP projects backbone features to a bottleneck, the bottleneck is
    L2-normalized, and a weight-normalized linear layer (magnitude frozen
    at 1) produces the output logits.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Initialize the MLP before attaching the weight-normalized output
        # layer, whose magnitude is pinned to 1.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases for every Linear layer.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        # fp16 needs a larger epsilon to keep the normalization stable.
        eps = 1e-6 if projected.dtype == torch.float16 else 1e-12
        normalized = nn.functional.normalize(projected, dim=-1, p=2, eps=eps)
        return self.last_layer(normalized)
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
models/encoders/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors.

    A per-sample Bernoulli mask with survival probability 1 - drop_prob is
    broadcast over all non-batch dimensions; survivors are divided by the
    survival probability so the expected value is unchanged. At eval time,
    or when drop_prob is 0, the input is returned untouched.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcastable against x of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
24
+
25
+
26
class DropPath(nn.Module):
    """Per-sample stochastic depth for the main path of residual blocks."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping an entire sample's residual branch.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; only active while training.
        return drop_path(x, self.drop_prob, self.training)
models/encoders/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
class LayerScale(nn.Module):
    """Scale each channel by a learned per-dimension gain (layer scale).

    The gain is initialized to a small constant so residual branches start
    close to the identity; `inplace` multiplies the input buffer directly.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
models/encoders/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear, with
    dropout applied after each projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
models/encoders/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Coerce an int into a symmetric 2-tuple; pass a 2-tuple through unchanged."""
    if isinstance(x, int):
        return (x, x)
    # Anything else must already be an explicit (H, W) pair.
    assert isinstance(x, tuple) and len(x) == 2
    return x
23
+
24
+
25
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    A non-overlapping Conv2d performs the per-patch linear projection.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: Return (B,N,D) tokens if True, else (B,H',W',D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        grid_h = image_HW[0] // patch_HW[0]
        grid_w = image_HW[1] // patch_HW[1]

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Stride == kernel size: each patch is projected exactly once.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        projected = self.proj(x)  # B C H' W'
        H, W = projected.size(2), projected.size(3)
        tokens = self.norm(projected.flatten(2).transpose(1, 2))  # B N C
        if self.flatten_embedding:
            return tokens
        return tokens.reshape(-1, H, W, self.embed_dim)  # B H W C

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        # Projection cost per output token times the number of tokens.
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # NOTE(review): self.norm is nn.Identity (never None) when no
        # norm_layer is given, so this branch always adds the norm cost —
        # confirm that is intended.
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
models/encoders/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: one fused projection emits gate and value halves,
    combined as silu(gate) * value before the output projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Single matmul produces both halves of the gated activation.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
35
+
36
+
37
# Prefer the fused xFormers SwiGLU kernel unless explicitly disabled via the
# XFORMERS_DISABLED environment variable; fall back to the pure-PyTorch
# SwiGLUFFN defined above when xFormers is missing or disabled.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (SwiGLU)")
    else:
        # Disabled by the user: raise to reuse the fallback branch below.
        warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    # Alias so SwiGLUFFNFused can subclass the same name either way.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width follows the 2/3 convention, rounded up
    to a multiple of 8 for kernel-friendly shapes."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Shrink to ~2/3 (compensating for the doubled w12 projection) and
        # pad up to the next multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
models/encoders/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
def build_model(args, only_teacher=False, img_size=224):
    """Build DINOv2 teacher/student ViTs from a config namespace.

    Strips a legacy "_memeff" suffix from args.arch, constructs the teacher,
    and — unless only_teacher — a student sharing the same kwargs plus
    stochastic-depth settings.

    Returns (teacher, embed_dim) when only_teacher is True, otherwise
    (student, teacher, embed_dim).
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        shared_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        arch_ctor = vits.__dict__[args.arch]
        teacher = arch_ctor(**shared_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        # The student additionally applies stochastic depth during training.
        student = arch_ctor(
            **shared_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        return student, teacher, student.embed_dim
40
+
41
+
42
def build_model_from_cfg(cfg, only_teacher=False):
    """Convenience wrapper: build models from a full training config."""
    student_cfg = cfg.student
    crop_size = cfg.crops.global_crops_size
    return build_model(student_cfg, only_teacher=only_teacher, img_size=crop_size)
models/encoders/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DINOv2 (https://github.com/facebookresearch/dinov2)
2
+ # ------------------------------------------------------------------------
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ #
5
+ # This source code is licensed under the Apache License, Version 2.0
6
+ # found in the LICENSE file in the root directory of this source tree.
7
+
8
+ # References:
9
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
10
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
11
+
12
+ from functools import partial
13
+ import math
14
+ import logging
15
+ from typing import Sequence, Tuple, Union, Callable
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+ from torch.nn.init import trunc_normal_
21
+
22
+ import copy
23
+
24
+ from models.encoders.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
25
+
26
+
27
+ logger = logging.getLogger("dinov2")
28
+
29
+
30
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply fn(module=..., name=...) over a module tree.

    Children are always visited with dot-qualified names; the root itself is
    visited only when include_root is True — before its children when
    depth_first is False, after them otherwise. Returns the root module.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
39
+
40
+
41
class BlockChunk(nn.ModuleList):
    """A ModuleList that, when called, runs its blocks sequentially."""

    def forward(self, x):
        for block in self:
            x = block(x)
        return x
46
+
47
+
48
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer with register tokens, masked-token support,
    several positional-embedding interpolation variants, and optional extra
    (deep-copied) blocks appended after the backbone."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        num_additional_blocks = 0,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            num_additional_blocks: (int) number of leading blocks to deep-copy
                into self.additional_blocks (requires block_chunks == 0)
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.img_size = img_size

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Position embedding covers cls token + patch tokens; register tokens
        # carry no positional information.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        if num_additional_blocks > 0:
            # Extra trainable blocks cloned from the first backbone blocks;
            # only supported with an unchunked block list.
            assert not self.chunked_blocks
            self.additional_blocks = copy.deepcopy(self.blocks[:num_additional_blocks])

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Token substituted for masked patches in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        # Truncated-normal position embedding; near-zero cls/register tokens;
        # timm-style init for all Linear layers.
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h, with_cls_token = True):
        """Bicubically resample the (square) patch position grid to match an
        input of w x h pixels; optionally prepend the cls position."""
        previous_dtype = x.dtype
        # npatch = x.shape[1] - 1 if with_cls_token else x.shape[1]
        N = self.pos_embed.shape[1] - 1
        # if npatch == N and w == h:
        #     return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        if with_cls_token:
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
        else:
            return patch_pos_embed.to(previous_dtype)

    def interpolate_pos_encoding2(self, x, input_size, feature_h, feature_w):
        """Resample the patch grid to a square input_size/patch_size grid,
        then crop the top-left feature_h x feature_w region (no cls token)."""
        previous_dtype = x.dtype
        N = self.pos_embed.shape[1] - 1

        pos_embed = self.pos_embed.float()
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = input_size // self.patch_size
        h0 = input_size // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed[...,:feature_h,:feature_w].permute(0, 2, 3, 1).reshape(1, -1, dim)
        return patch_pos_embed.to(previous_dtype)

    def interpolate_pos_encoding3(self, target_size):
        """Resample the patch grid to target_size x target_size and return it
        as an (H, W, dim) map (no batch axis, no cls token)."""
        # previous_dtype = x.dtype
        N = self.pos_embed.shape[1] - 1
        pos_embed = self.pos_embed.float()
        patch_pos_embed = pos_embed[:, 1:]
        dim = self.embed_dim
        w0 = target_size
        h0 = target_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.squeeze(0).permute(1,2,0)

        return patch_pos_embed

    def prepare_tokens_with_masks(self, x, masks=None, with_pos_embed = True):
        """Patchify x, optionally replace masked patches with the mask token,
        prepend cls (and register) tokens, and add position embeddings."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Masked patches are replaced with the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        if with_pos_embed:
            x = x + self.interpolate_pos_encoding(x, w, h)

        # Register tokens are inserted after cls and carry no pos. embedding.
        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def prepare_tokens_with_masks2(self, x, masks):
        """Variable-length variant: embed only the patches selected by a
        boolean (B, H', W') mask, returning the packed tokens, the per-sample
        token counts, and the combined cls(+pos)/register token row.

        NOTE(review): assumes 3-channel input (hard-coded 3 in the view) —
        confirm against callers.
        """
        assert masks.ndim == 3
        B, nc, w, h = x.shape
        token_lens = masks.flatten(1).sum(1).tolist()
        # Rearrange pixels into per-patch blocks, then project only the
        # selected patches.
        patched_x = x.view(B, 3, w//self.patch_size, self.patch_size, h//self.patch_size, self.patch_size)
        patched_x = patched_x.permute(0, 2, 4, 1, 3, 5)
        x = self.patch_embed.norm(self.patch_embed.proj(patched_x[masks]).flatten(1))
        pos_embed = self.interpolate_pos_encoding(x, w, h, with_cls_token=False).repeat(B, 1, 1)
        x = x + pos_embed[masks.flatten(1)]

        # cls token with its position, optionally followed by registers.
        cr_token = self.cls_token.view(1,-1) + self.pos_embed.float()[:,0].view(1,-1)
        if self.register_tokens is not None:
            cr_token = torch.cat([cr_token, self.register_tokens.view(self.num_register_tokens, -1)])

        # x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        # if self.register_tokens is not None:
        #     x = torch.cat(
        #         (
        #             x[:, :1],
        #             self.register_tokens.expand(x.shape[0], -1, -1),
        #             x[:, 1:],
        #         ),
        #         dim=1,
        #     )

        return x, token_lens, cr_token

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of (images, masks) pairs via the nested
        block path; returns one output dict per list entry."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and split the normalized output into cls,
        register, and patch tokens."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def forward_specific_layers(self, x, start=0, end=None, norm=True):
        """Run blocks [start:end) on already-prepared tokens; returns the raw
        tokens and the (optionally normalized) patch tokens."""
        assert not self.chunked_blocks
        if end is None:
            end = len(self.blocks)
        for blk in self.blocks[start:end]:
            x = blk(x)
        out = x[:, 1 + self.num_register_tokens :]
        if norm:
            out = self.norm(out)
        return x, out

    def forward_specific_layers_list(self, x_list, start=0, end=None, norm=True, get_feature=True):
        """List-input variant of forward_specific_layers (nested block path)."""
        assert not self.chunked_blocks
        if end is None:
            end = len(self.blocks)
        for blk in self.blocks[start:end]:
            x_list = blk(x_list)

        if get_feature:
            out_list = [x[:, 1 + self.num_register_tokens:, :] for x in x_list]
            if norm:
                out_list = [self.norm(out) for out in out_list]
            return x_list, out_list
        else:
            return x_list

    def forward_additional_layers_list(self, x_list, start=0, end=None, norm=True, get_feature=True):
        """Same as forward_specific_layers_list but over the deep-copied
        additional_blocks (requires num_additional_blocks > 0)."""
        assert not self.chunked_blocks
        if end is None:
            end = len(self.additional_blocks)
        for blk in self.additional_blocks[start:end]:
            x_list = blk(x_list)

        if get_feature:
            out_list = [x[:, 1 + self.num_register_tokens:, :] for x in x_list]
            if norm:
                out_list = [self.norm(out) for out in out_list]
            return x_list, out_list
        else:
            return x_list

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        # Collect raw outputs of the selected blocks (unchunked layout).
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        # Chunked layout: skip the leading nn.Identity padding in each chunk
        # so the global block index i stays consistent.
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally cls tokens) from selected
        blocks, optionally normalized and reshaped to (B, C, H', W')."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Full feature dict when is_training, else the (Identity) head
        applied to the normalized cls token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
476
+
477
+
478
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility):
    truncated-normal Linear weights (std 0.02) and zero biases."""
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
+
485
+
486
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S backbone (embed 384, depth 12, 6 heads) with MemEffAttention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
498
+
499
+
500
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B backbone (embed 768, depth 12, 12 heads) with MemEffAttention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
512
+
513
+
514
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L backbone (embed 1024, depth 24, 16 heads) with MemEffAttention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
526
+
527
+
528
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
models/human_models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .smpl_models import SMPL_Layer, smpl_gendered
models/human_models/smpl_models.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import smplx
4
+ import numpy as np
5
+ import pickle
6
+ import os.path as osp
7
+ from configs.paths import smpl_model_path
8
+
9
+
class SMPL_Layer(nn.Module):
    def __init__(self, model_path, with_genders = True, **kwargs):
        """
        Extension of the SMPL Layer with gendered inputs.

        Args:
            model_path: root directory of the SMPL model files (passed to smplx.create
                and used to locate the auxiliary .npy regressors under 'smpl/').
            with_genders: if True, build neutral/male/female sub-layers; otherwise
                only the neutral layer is available.
            **kwargs: forwarded to smplx.create, overriding the defaults below.
        """
        super().__init__()
        # Pose/shape parameters are supplied at forward time, so disable smplx's
        # own creation of default (learnable) parameter tensors.
        smpl_kwargs = {'create_global_orient': False, 'create_body_pose': False,
                       'create_betas': False, 'create_transl': False}
        smpl_kwargs.update(kwargs)
        self.with_genders = with_genders
        if self.with_genders:
            self.layer_n = smplx.create(model_path, 'smpl', gender='neutral', **smpl_kwargs)
            self.layer_m = smplx.create(model_path, 'smpl', gender='male', **smpl_kwargs)
            self.layer_f = smplx.create(model_path, 'smpl', gender='female', **smpl_kwargs)
            # Plain dict for lookup by gender string; the layers themselves are
            # registered as submodules via the attribute assignments above.
            self.layers = {'neutral': self.layer_n, 'male': self.layer_m, 'female': self.layer_f}
        else:
            self.layer_n = smplx.create(model_path, 'smpl', gender='neutral', **smpl_kwargs)
            self.layers = {'neutral': self.layer_n}

        self.vertex_num = 6890  # SMPL template vertex count
        self.faces = self.layer_n.faces  # triangle faces shared by all genders

        # Auxiliary data shipped alongside the model files: a body-vertex subset
        # and the corrected H3.6M joint regressor.
        self.body_vertex_idx = np.load(osp.join(model_path, 'smpl', 'body_verts_smpl.npy'))
        self.smpl2h36m_regressor = np.load(osp.join(model_path, 'smpl', 'J_regressor_h36m_correct.npy'))


    def forward_single_gender(self, poses, betas, gender='neutral'):
        """Run a single gendered SMPL layer on a full batch.

        Args:
            poses: (B, 72) or (B, 24, 3) axis-angle poses; joint 0 is the
                global orientation.
            betas: (B, n_betas) shape coefficients.
            gender: key into self.layers ('neutral', 'male' or 'female').

        Returns:
            (vertices, joints) tensors from the smplx output.
        """
        bs = poses.shape[0]
        if poses.ndim == 2:
            # Flat (B, 72) layout -> (B, 24, 3) per-joint axis-angle.
            poses = poses.view(bs, -1, 3)

        assert poses.shape[1] == 24
        pose_params = {'global_orient': poses[:, :1, :],
                       'body_pose': poses[:, 1:, :]}

        smpl_output = self.layers[gender](betas=betas, **pose_params)
        return smpl_output.vertices, smpl_output.joints

    def forward(self, poses, betas, genders = None):
        """Gender-aware forward pass.

        Args:
            poses: (B, 72) or (B, 24, 3) axis-angle poses.
            betas: (B, n_betas) shape coefficients.
            genders: optional length-B sequence of 'male'/'female' labels.
                None runs the whole batch through the neutral model.

        Returns:
            (vertices, joints) for the full batch.
        """
        bs = poses.shape[0]
        assert poses.shape[0] == betas.shape[0]
        if genders is None:
            # No per-sample labels: neutral model for everyone.
            return self.forward_single_gender(poses, betas)
        else:
            assert len(genders) == bs
            assert set(genders) <= {'male', 'female'}  # 'neutral' not allowed here
            assert self.with_genders

            male_idx = [i for i, gender in enumerate(genders) if gender == 'male']
            if len(male_idx) == bs:
                # Homogeneous male batch: single layer call.
                return self.forward_single_gender(poses, betas, gender='male')
            elif len(male_idx) == 0:
                # Homogeneous female batch: single layer call.
                return self.forward_single_gender(poses, betas, gender='female')
            else:
                # Mixed batch: run the female model for everyone, then
                # overwrite the male rows with the male model's output.
                vertices, joints = self.forward_single_gender(poses, betas, gender='female')
                vertices[male_idx], joints[male_idx] =\
                    self.forward_single_gender(poses[male_idx], betas[male_idx], gender='male')
                return vertices, joints
# Module-level singleton: gendered SMPL layers shared across the project.
# NOTE: instantiated at import time, so importing this module requires the
# model files to exist under smpl_model_path.
smpl_gendered = SMPL_Layer(smpl_model_path, with_genders = True)
models/matcher.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
2
+
3
+ import torch
4
+ from scipy.optimize import linear_sum_assignment
5
+ from torch import nn
6
+
7
+ from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
8
+
9
+
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_conf: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1,
                 cost_kpts: float = 10,
                 j2ds_norm_scale: float = 518,
                 ):
        """Creates the matcher.

        Params:
            cost_conf: relative weight of the confidence (focal-style) term in the matching cost
            cost_bbox: relative weight of the L1 error of the bounding box coordinates
            cost_giou: relative weight of the GIoU loss of the bounding box
            cost_kpts: relative weight of the 2D-keypoint L1 error (decoder matching only)
            j2ds_norm_scale: divisor (the input resolution) bringing 2D keypoints to ~[0, 1]
        """
        super().__init__()
        self.cost_conf = cost_conf
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.cost_kpts = cost_kpts
        self.j2ds_norm_scale = j2ds_norm_scale
        assert cost_conf != 0 or cost_bbox != 0 or cost_giou != 0 or cost_kpts != 0, "all costs cant be 0"

        # self.focal_alpha = focal_alpha

    @torch.no_grad()
    def forward_enc(self, outputs, targets):
        """Matching for encoder outputs, where samples have variable numbers of predictions.

        Params:
            outputs: dict that contains at least these entries:
                "pred_confs": Tensor of flattened confidence scores, [total_lens, 1]
                "pred_boxes": Tensor of flattened boxes, [total_lens, 4]
                "lens": num of predictions for each sample in the batch, sum(lens) == total_lens
            targets: list of targets (len(targets) = batch_size); each target is a dict containing:
                "boxes": Tensor of dim [num_target_boxes, 4] with the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_predictions_i, num_target_boxes_i)
        """
        out_conf = outputs['pred_confs']
        out_bbox = outputs["pred_boxes"]
        lens = outputs['lens']
        assert len(lens) == len(targets)
        assert tuple(out_conf.shape) == (sum(lens), 1)
        assert tuple(out_bbox.shape) == (sum(lens), 4)

        # Concat the target boxes across the batch (columns of the cost matrix).
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Confidence cost: focal-loss-style positive term (out_conf is already a probability).
        alpha = 0.25
        gamma = 2.0
        cost_conf = alpha * ((1 - out_conf) ** gamma) * (-(out_conf + 1e-8).log())
        # cost_conf = -(out_conf+1e-8).log()

        # L1 cost between all prediction/target box pairs.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # GIoU cost between boxes (negated: higher IoU -> lower cost).
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix over all (prediction, target) pairs.
        C = self.cost_conf * cost_conf + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
        C = C.cpu()

        # Per-sample Hungarian assignment: split columns by each sample's target
        # count and slice the matching rows via the running offset `idx`.
        sizes = [len(v["boxes"]) for v in targets]
        idx = 0
        indices = []
        for i, c in enumerate(C.split(sizes, -1)):
            indices.append(linear_sum_assignment(c[idx:idx + lens[i]]))
            idx += lens[i]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Matching for decoder outputs (fixed num_queries per sample).

        Params:
            outputs: dict that contains at least these entries:
                "pred_confs": Tensor of dim [batch_size, num_queries, 1] with confidence scores
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with predicted box coordinates
                "pred_j2ds": predicted 2D joints; only the first 22 joints are used for matching
            targets: list of targets (len(targets) = batch_size); each target is a dict containing:
                "boxes": Tensor of dim [num_target_boxes, 4] with the target box coordinates
                "j2ds": ground-truth 2D joints
                "j2ds_mask": per-joint visibility mask
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        assert outputs['pred_confs'].shape[0] == len(targets)
        bs, num_queries, _ = outputs["pred_confs"].shape

        # Flatten batch and query dims to compute the cost matrices in one go.
        out_conf = outputs['pred_confs'].flatten(0, 1)  # [batch_size * num_queries, 1]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Keep only the first 22 (body) joints, normalized to ~[0, 1].
        out_kpts = outputs['pred_j2ds'][..., :22, :].flatten(2).flatten(0, 1) / self.j2ds_norm_scale

        # Concat the target boxes/keypoints across the batch.
        tgt_bbox = torch.cat([v["boxes"] for v in targets])
        tgt_kpts = torch.cat([v['j2ds'][:, :22, :].flatten(1) for v in targets]) / self.j2ds_norm_scale
        tgt_kpts_mask = torch.cat([v['j2ds_mask'][:, :22, :].flatten(1) for v in targets])
        tgt_kpts_vis_cnt = tgt_kpts_mask.sum(-1)
        # Every target must have at least one visible joint (avoids div-by-zero below).
        assert (torch.all(tgt_kpts_vis_cnt))

        # Confidence cost: focal-loss-style positive term.
        alpha = 0.25
        gamma = 2.0
        cost_conf = alpha * ((1 - out_conf) ** gamma) * (-(out_conf + 1e-8).log())
        # cost_conf = -(out_conf+1e-8).log()

        # L1 cost between boxes.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # GIoU cost between boxes (negated).
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Mean L1 cost over *visible* joints only.
        all_dist = torch.abs(out_kpts[:, None, :] - tgt_kpts[None, :, :])
        mean_dist = (all_dist * tgt_kpts_mask[None, :, :]).sum(-1) / tgt_kpts_vis_cnt[None, :]
        cost_kpts = mean_dist

        # Final cost matrix, reshaped back to [bs, num_queries, total_targets].
        C = self.cost_conf * cost_conf + self.cost_kpts * cost_kpts + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        # Per-sample Hungarian assignment on each sample's block of columns.
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
def build_matcher(args):
    """Construct the Hungarian matcher from the parsed config ``args``."""
    weights = dict(
        cost_conf=args.set_cost_conf,
        cost_bbox=args.set_cost_bbox,
        cost_giou=args.set_cost_giou,
        cost_kpts=args.set_cost_kpts,
    )
    return HungarianMatcher(j2ds_norm_scale=args.input_size, **weights)
models/position_encoding.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
2
+ """
3
+ Various positional encodings for the transformer.
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+
9
+ from utils.misc import NestedTensor
10
+
11
+
def position_encoding_xy(pos_x, pos_y, embedding_dim, temperature = 20, scale = 2*math.pi):
    """Sinusoidal embedding for flat lists of normalized (x, y) positions.

    Args:
        pos_x, pos_y: 1-D tensors of positions (typically in [0, 1]).
        embedding_dim: output embedding width; must be even (half per axis).
        temperature: base of the geometric frequency progression.
        scale: phase multiplier applied to the raw coordinates.
    Returns:
        (N, embedding_dim) tensor: y-embedding concatenated before x-embedding.
    """
    assert embedding_dim % 2 == 0
    assert pos_x.ndim == 1 and pos_y.ndim == 1
    half = embedding_dim // 2
    freqs = torch.arange(half, dtype=torch.float32, device=pos_x.device)
    freqs = temperature ** (2 * (freqs // 2) / half)

    def _embed(coords):
        # Interleave sin on even channels with cos on odd channels.
        phase = (coords * scale)[:, None] / freqs
        return torch.stack((phase[:, 0::2].sin(), phase[:, 1::2].cos()), dim=2).flatten(1)

    return torch.cat([_embed(pos_y), _embed(pos_x)], dim=1)
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        feats = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        valid = ~mask
        # Cumulative counts of valid pixels give monotone row/column coordinates.
        y_embed = valid.cumsum(1, dtype=torch.float32)
        x_embed = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=feats.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) with cos (odd channels) per axis.
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
class PositionEmbeddingSineHW(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    Variant with separate temperatures for the height (y) and width (x) axes.
    """
    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        # Cumulative sums over valid pixels give monotone y/x coordinates that
        # ignore padded regions of the batch.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            # Scale each axis to [0, self.scale] using its last valid coordinate.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Per-axis frequency progressions (separate temperature per axis).
        dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_tx

        dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
        pos_y = y_embed[:, :, :, None] / dim_ty

        # Interleave sin (even channels) with cos (odd channels) per axis, then
        # concatenate [y | x] and move channels first: (B, 2*num_pos_feats, H, W).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        # One learned table per axis; supports feature maps up to 50x50.
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        feats = tensor_list.tensors
        h, w = feats.shape[-2:]
        cols = torch.arange(w, device=feats.device)
        rows = torch.arange(h, device=feats.device)
        x_emb = self.col_embed(cols)  # (w, C)
        y_emb = self.row_embed(rows)  # (h, C)
        # Broadcast each axis over the other and concatenate channel-wise.
        grid = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1)
        # (H, W, 2C) -> (B, 2C, H, W)
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(feats.shape[0], 1, 1, 1)
def build_position_encoding(args):
    """Instantiate the positional encoding selected by ``args.position_embedding``."""
    n_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSineHW(
            n_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True
        )
    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(n_steps)
    raise ValueError(f"not supported {args.position_embedding}")
models/sat_model.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
2
+
3
+ import os
4
+
5
+ import math
6
+ from math import tan,pi
7
+ from typing import Dict
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torchvision.transforms import Resize
12
+ import numpy as np
13
+ import time
14
+ import random
15
+
16
+ from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
17
+ accuracy, get_world_size, interpolate,
18
+ is_dist_avail_and_initialized, inverse_sigmoid)
19
+
20
+ from utils.transforms import rot6d_to_axis_angle, img2patch_flat, img2patch, to_zorder
21
+ from utils.map import build_z_map
22
+ from utils import constants
23
+ from configs.paths import smpl_mean_path
24
+
25
+ from models.encoders import build_encoder
26
+ from .matcher import build_matcher
27
+ from .decoder import build_decoder
28
+ from .position_encoding import position_encoding_xy
29
+ from .criterion import SetCriterion
30
+ from .dn_components import prepare_for_cdn, dn_post_process
31
+ import copy
32
+
33
+ from configs.paths import smpl_model_path
34
+ from models.human_models import SMPL_Layer
35
+
36
+
37
+ def _get_clones(module, N):
38
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
39
+
40
+ class Model(nn.Module):
41
+ """ One-stage Multi-person Human Mesh Estimation via Scale-adaptive Tokens """
42
+ def __init__(self, encoder, decoder,
43
+ num_queries,
44
+ input_size,
45
+ sat_cfg = {'use_sat': False},
46
+ dn_cfg = {'use_dn': False},
47
+ train_pos_embed = True,
48
+ aux_loss=True,
49
+ iter_update=True,
50
+ query_dim=4,
51
+ bbox_embed_diff_each_layer=True,
52
+ random_refpoints_xy=False,
53
+ num_poses=24,
54
+ dim_shape=10,
55
+ FOV=pi/3
56
+ ):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ encoder: torch module of the encoder to be used. See ./encoders.
60
+ decoder: torch module of the decoder architecture. See decoder.py
61
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
62
+ DETR can detect in a single image.
63
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
64
+ iter_update: iterative update of boxes
65
+ query_dim: query dimension. 2 for point and 4 for box.
66
+ bbox_embed_diff_each_layer: dont share weights of prediction heads. Default for False. (shared weights.)
67
+ random_refpoints_xy: random init the x,y of anchor boxes and freeze them. (It sometimes helps to improve the performance)
68
+ """
69
+ super().__init__()
70
+
71
+ # ========== Start of common settings =============
72
+ self.input_size = input_size
73
+ hidden_dim = decoder.d_model
74
+ num_dec_layers = decoder.dec_layers
75
+ self.hidden_dim = hidden_dim
76
+ # camera model
77
+ self.focal = input_size/(2*tan(FOV/2))
78
+ self.FOV = FOV
79
+ cam_intrinsics = torch.tensor([[self.focal,0.,self.input_size/2],
80
+ [0.,self.focal,self.input_size/2],
81
+ [0.,0.,1.]])
82
+ self.register_buffer('cam_intrinsics', cam_intrinsics)
83
+ # human model
84
+ self.num_poses = num_poses
85
+ self.dim_shape = dim_shape
86
+ self.human_model = SMPL_Layer(model_path = smpl_model_path, with_genders = False)
87
+ # init params (following multi-hmr)
88
+ smpl_mean_params = np.load(smpl_mean_path, allow_pickle = True)
89
+ self.register_buffer('mean_pose', torch.from_numpy(smpl_mean_params['pose']))
90
+ self.register_buffer('mean_shape', torch.from_numpy(smpl_mean_params['shape']))
91
+ # ========== End of common settings =============
92
+
93
+
94
+ # ========== Start of SAT-encoder settings =============
95
+ self.encoder = encoder
96
+
97
+ self.patch_size = encoder.patch_size
98
+ assert self.patch_size == 14
99
+
100
+ self.use_sat = sat_cfg['use_sat']
101
+ self.sat_cfg = sat_cfg
102
+
103
+ if self.use_sat:
104
+ assert sat_cfg['num_lvls'] >= 2
105
+ assert self.input_size % (self.patch_size<<2) == 0
106
+
107
+ self.feature_size = []
108
+ for lvl in range(sat_cfg['num_lvls']):
109
+ patch_size = self.patch_size<<lvl
110
+ self.feature_size.append(self.input_size / patch_size)
111
+
112
+ # build z_order curve
113
+ z_depth = math.ceil(math.log2(self.feature_size[1]))
114
+ z_map, ys, xs = build_z_map(z_depth)
115
+ self.register_buffer('z_order_map', z_map)
116
+ self.register_buffer('y_coords', ys)
117
+ self.register_buffer('x_coords', xs)
118
+
119
+ self.enc_inter_norm = copy.deepcopy(encoder.norm)
120
+ self.scale_head = MLP(encoder.embed_dim, encoder.embed_dim, 2, 4)
121
+ self.encoder_patch_proj = _get_clones(encoder.patch_embed.proj, 2)
122
+ self.encoder_patch_norm = _get_clones(encoder.patch_embed.norm, 2)
123
+
124
+ if sat_cfg['lvl_embed']:
125
+ # same as level_embed in Deformable-DETR
126
+ self.level_embed = nn.Parameter(torch.Tensor(sat_cfg['num_lvls'],hidden_dim))
127
+ nn.init.normal_(self.level_embed)
128
+ else:
129
+ assert self.input_size % self.patch_size == 0
130
+ self.feature_size = [self.input_size // self.patch_size]
131
+ self.encoder_patch_proj = copy.deepcopy(encoder.patch_embed.proj)
132
+ self.encoder_patch_norm = copy.deepcopy(encoder.patch_embed.norm)
133
+
134
+ # cls_token and register tokens
135
+ encoder_cr_token = self.encoder.cls_token.view(1,-1) + self.encoder.pos_embed.float()[:,0].view(1,-1)
136
+ if self.encoder.register_tokens is not None:
137
+ encoder_cr_token = torch.cat([encoder_cr_token, self.encoder.register_tokens.view(self.encoder.num_register_tokens,-1)], dim=0)
138
+ self.encoder_cr_token = nn.Parameter(encoder_cr_token)
139
+
140
+ self.encoder_pos_embeds = nn.Parameter(self.encoder.interpolate_pos_encoding3(self.feature_size[0]).detach())
141
+ if not train_pos_embed:
142
+ self.encoder_pos_embeds.requires_grad = False
143
+
144
+ self.preprocessed_pos_lvl1 = None
145
+
146
+ # delete unwanted params
147
+ del(self.encoder.mask_token)
148
+ del(self.encoder.pos_embed)
149
+ del(self.encoder.patch_embed)
150
+ del(self.encoder.cls_token)
151
+ del(self.encoder.register_tokens)
152
+ # ========== End of SAT-encoder settings =============
153
+
154
+
155
+
156
+ # ========== Start of decoder settings =============
157
+ self.num_queries = num_queries
158
+ self.decoder = decoder
159
+
160
+ # embed_dim between encoder and decoder can be different
161
+ self.feature_proj = nn.Linear(encoder.embed_dim, hidden_dim)
162
+
163
+ # bbox
164
+ self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
165
+ if bbox_embed_diff_each_layer:
166
+ self.bbox_embed = nn.ModuleList([MLP(hidden_dim, hidden_dim, 4, 3) for i in range(num_dec_layers)])
167
+ else:
168
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
169
+ # poses (use 6D rotation)
170
+ self.pose_head = MLP(hidden_dim, hidden_dim, num_poses*6, 6)
171
+ # shape
172
+ self.shape_head = MLP(hidden_dim, hidden_dim, dim_shape, 5)
173
+ # cam_trans
174
+ self.cam_head = MLP(hidden_dim, hidden_dim//2, 3, 3)
175
+ # confidence score
176
+ self.conf_head = nn.Linear(hidden_dim, 1)
177
+ # init prior_prob setting for focal loss
178
+ prior_prob = 0.01
179
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
180
+ self.conf_head.bias.data = torch.ones(1) * bias_value
181
+
182
+ # for iter update
183
+ self.pose_head = _get_clones(self.pose_head, num_dec_layers)
184
+ self.shape_head = _get_clones(self.shape_head, num_dec_layers)
185
+
186
+ # setting query dim (bboxes as queries)
187
+ self.query_dim = query_dim
188
+ assert query_dim == 4
189
+ self.refpoint_embed = nn.Embedding(num_queries, query_dim)
190
+ self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
191
+
192
+ self.random_refpoints_xy = random_refpoints_xy
193
+ if random_refpoints_xy:
194
+ # import ipdb; ipdb.set_trace()
195
+ self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
196
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
197
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
198
+
199
+ self.aux_loss = aux_loss
200
+ self.iter_update = iter_update
201
+ assert iter_update
202
+ if self.iter_update:
203
+ self.decoder.decoder.bbox_embed = self.bbox_embed
204
+
205
+ assert bbox_embed_diff_each_layer
206
+ if bbox_embed_diff_each_layer:
207
+ for bbox_embed in self.bbox_embed:
208
+ nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
209
+ nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
210
+ else:
211
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
212
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
213
+ # ========== End of decoder settings =============
214
+
215
+
216
+ # for dn training
217
+ self.use_dn = dn_cfg['use_dn']
218
+ self.dn_cfg = dn_cfg
219
+ if self.use_dn:
220
+ assert dn_cfg['dn_number'] > 0
221
+ if dn_cfg['tgt_embed_type'] == 'labels':
222
+ self.dn_enc = nn.Embedding(dn_cfg['dn_labelbook_size'], hidden_dim)
223
+ elif dn_cfg['tgt_embed_type'] == 'params':
224
+ self.dn_enc = nn.Linear(num_poses*3 + dim_shape, hidden_dim)
225
+ else:
226
+ raise NotImplementedError
227
+
228
+
229
+ def lvl_pooling(self, tokens):
230
+ assert len(tokens)%4 == 0
231
+ C = tokens.shape[-1]
232
+ return torch.max(tokens.view(-1, 4, C), dim=1)[0]
233
+
234
+ def get_scale_map(self, x_list):
235
+ if self.sat_cfg['use_additional_blocks']:
236
+ x_list = self.encoder.forward_additional_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
237
+ else:
238
+ x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
239
+
240
+ cr_token_list = [x[:, :1 + self.encoder.num_register_tokens, :].squeeze(0) for x in x_list]
241
+ x_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
242
+ scale_map = self.scale_head(self.enc_inter_norm(x_tokens)).sigmoid()
243
+ return scale_map, cr_token_list, x_tokens
244
+
245
+ def pad_mask(self, mask):
246
+ mask = mask.reshape(-1,4)
247
+ mask[torch.any(mask, dim=1)] = True
248
+ return mask.flatten()
249
+
250
+ def forward_encoder(self, samples, targets, use_gt = False):
251
+ B = len(samples)
252
+ C = self.encoder.embed_dim
253
+ cr_token_list = [self.encoder_cr_token]*len(samples)
254
+
255
+ if not self.use_sat:
256
+ # img2token
257
+ lvl0_feature_hw = [(img.shape[1]//self.patch_size, img.shape[2]//self.patch_size) for img in samples]
258
+ lvl0_token_lens = [h*w for (h,w) in lvl0_feature_hw]
259
+ lvl0_img_patches = torch.cat([img2patch_flat(img, patch_size = self.patch_size)\
260
+ for img in samples], dim=0)
261
+ lvl0_tokens = self.encoder_patch_norm(self.encoder_patch_proj(lvl0_img_patches).flatten(1))
262
+
263
+ # token position information
264
+ full_grids = torch.meshgrid(torch.arange(self.feature_size[0]), torch.arange(self.feature_size[0]), indexing='ij')
265
+ lvl0_pos_y = torch.cat([full_grids[0][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)
266
+ lvl0_pos_x = torch.cat([full_grids[1][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)
267
+
268
+ # pos_embed
269
+ full_pos_embed = self.encoder_pos_embeds
270
+ lvl0_pos_embed = torch.cat([full_pos_embed[:h,:w].flatten(0,1)\
271
+ for (h,w) in lvl0_feature_hw], dim=0)
272
+ lvl0_tokens = lvl0_tokens + lvl0_pos_embed
273
+
274
+ # convert to list for DINOv2 input
275
+ x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
276
+ for (cr, lvl0) \
277
+ in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]
278
+
279
+
280
+ lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
281
+ lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
282
+ pos_x_list = list(lvl0_pos_y_norm.split(lvl0_token_lens))
283
+ pos_y_list = list(lvl0_pos_x_norm.split(lvl0_token_lens))
284
+ scale_map_dict = None
285
+ # also create lvl_list for patch visualization
286
+ lvl_list = [torch.zeros_like(pos,dtype=int) for pos in pos_x_list]
287
+
288
+ else:
289
+ lvl1_feature_hw = [(img.shape[1]//(2*self.patch_size), img.shape[2]//(2*self.patch_size)) for img in samples]
290
+ lvl1_token_lens = [h*w for (h,w) in lvl1_feature_hw]
291
+
292
+ lvl1_img_patches_28, lvl1_zorders = [], []
293
+ lvl1_pos_y, lvl1_pos_x = [], []
294
+ lvl1_bids = []
295
+
296
+ for i, img in enumerate(samples):
297
+ z_patches, z_order, pos_y, pos_x = to_zorder(img2patch(img, patch_size = 2*self.patch_size),
298
+ z_order_map = self.z_order_map,
299
+ y_coords = self.y_coords,
300
+ x_coords = self.x_coords)
301
+
302
+ lvl1_img_patches_28.append(z_patches)
303
+
304
+ lvl1_zorders.append(z_order)
305
+ lvl1_pos_y.append(pos_y)
306
+ lvl1_pos_x.append(pos_x)
307
+ lvl1_bids.append(torch.full_like(pos_y, i, dtype=torch.int64))
308
+
309
+
310
+
311
+ lvl1_img_patches_28 = torch.cat(lvl1_img_patches_28, dim=0)
312
+ lvl1_zorders = torch.cat(lvl1_zorders, dim=0)
313
+ lvl1_pos_y = torch.cat(lvl1_pos_y, dim=0)
314
+ lvl1_pos_x = torch.cat(lvl1_pos_x, dim=0)
315
+ lvl1_bids = torch.cat(lvl1_bids, dim=0)
316
+
317
+
318
+
319
+ # (L1, 3, 28, 28)
320
+ assert len(lvl1_img_patches_28) == sum(lvl1_token_lens)
321
+ lvl1_img_patches = F.interpolate(lvl1_img_patches_28, size = (14,14), mode='bilinear', align_corners=False)
322
+ # (L1, 3, 14, 14)
323
+ lvl1_tokens = self.encoder_patch_norm[1](self.encoder_patch_proj[1](lvl1_img_patches).flatten(1))
324
+ # (L1, C)
325
+
326
+
327
+
328
+ assert len(lvl1_pos_y) == len(lvl1_tokens)
329
+ full_pos_embed = self.preprocessed_pos_lvl1 if not self.training\
330
+ else F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
331
+ mode="bicubic",
332
+ antialias=self.encoder.interpolate_antialias,
333
+ size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)
334
+ lvl1_pos_embed = torch.cat([full_pos_embed[ys,xs]\
335
+ for (ys,xs) in zip(lvl1_pos_y.split(lvl1_token_lens), lvl1_pos_x.split(lvl1_token_lens))], dim=0)
336
+ lvl1_tokens = lvl1_tokens + lvl1_pos_embed
337
+
338
+ # get scale map (flattened)
339
+ x_list = [torch.cat([cr, lvl1],dim=0).unsqueeze(0)\
340
+ for (cr, lvl1) \
341
+ in zip(cr_token_list, lvl1_tokens.split(lvl1_token_lens))]
342
+ scale_map, updated_cr_list, updated_lvl1 = self.get_scale_map(x_list)
343
+ # for visualization
344
+ scale_map_dict = {'scale_map': scale_map, 'lens': lvl1_token_lens, 'hw': lvl1_feature_hw,
345
+ 'pos_y': lvl1_pos_y, 'pos_x': lvl1_pos_x}
346
+
347
+ # get sat masks
348
+ conf_thresh = self.sat_cfg['conf_thresh']
349
+ scale_thresh = self.sat_cfg['scale_thresh']
350
+ if use_gt:
351
+ scale_map = torch.cat([tgt['scale_map'].view(-1,2) for tgt in targets], dim=0)
352
+
353
+ lvl1_valid_mask = scale_map[:,0] > conf_thresh
354
+ lvl1_sat_mask = lvl1_valid_mask & (scale_map[:,1] < scale_thresh)
355
+
356
+ # prepare sat tokens (lvl0)
357
+ lvl0_token_lens = [msk.sum().item()<<2 for msk in lvl1_sat_mask.split(lvl1_token_lens)]
358
+ lvl1_sat_patches_28 = lvl1_img_patches_28[lvl1_sat_mask] # (L0//4, 3, 28, 28)
359
+ lvl0_tokens = self.encoder_patch_norm[0](self.encoder_patch_proj[0](lvl1_sat_patches_28).permute(0, 2, 3, 1).flatten(0,2))
360
+
361
+ assert len(lvl0_tokens) == sum(lvl0_token_lens)
362
+ # lvl0 positions
363
+ lvl0_pos_y, lvl0_pos_x = lvl1_pos_y[lvl1_sat_mask], lvl1_pos_x[lvl1_sat_mask]
364
+ lvl0_pos_y = (lvl0_pos_y<<1)[:,None].repeat(1,4).flatten()
365
+ lvl0_pos_x = (lvl0_pos_x<<1)[:,None].repeat(1,4).flatten()
366
+ lvl0_pos_y[2::4] += 1
367
+ lvl0_pos_y[3::4] += 1
368
+ lvl0_pos_x[1::2] += 1
369
+ assert len(lvl0_pos_x) == len(lvl0_tokens)
370
+
371
+ # lvl0 pos_embed
372
+ full_pos_embed = self.encoder_pos_embeds
373
+ lvl0_pos_embed = torch.cat([full_pos_embed[ys,xs]\
374
+ for (ys,xs) in zip(lvl0_pos_y.split(lvl0_token_lens), lvl0_pos_x.split(lvl0_token_lens))], dim=0)
375
+ lvl0_tokens = lvl0_tokens + lvl0_pos_embed
376
+
377
+
378
+ # update tokens
379
+ x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
380
+ for (cr, lvl0) \
381
+ in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]
382
+ x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
383
+ lvl0_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
384
+ assert len(lvl0_pos_x) == len(lvl0_tokens)
385
+ # also update lvl1 and crs
386
+ lvl1_tokens = updated_lvl1
387
+ cr_token_list = updated_cr_list
388
+
389
+
390
+
391
+ if self.sat_cfg['num_lvls'] == 2:
392
+ # drop corresponding lvl1 tokens
393
+ lvl1_keep = ~lvl1_sat_mask
394
+ lvl1_token_lens = [msk.sum().item() for msk in lvl1_keep.split(lvl1_token_lens)]
395
+ lvl1_tokens = lvl1_tokens[lvl1_keep]
396
+ lvl1_pos_y = lvl1_pos_y[lvl1_keep]
397
+ lvl1_pos_x = lvl1_pos_x[lvl1_keep]
398
+
399
+ # normalize positions
400
+ lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
401
+ lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
402
+ lvl1_pos_y_norm = (lvl1_pos_y.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]
403
+ lvl1_pos_x_norm = (lvl1_pos_x.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]
404
+
405
+ # merge all
406
+ x_list = [torch.cat([cr, lvl0, lvl1]).unsqueeze(0) \
407
+ for cr, lvl0, lvl1 \
408
+ in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens), lvl1_tokens.split(lvl1_token_lens))]
409
+ pos_y_list = [torch.cat([lvl0, lvl1]) \
410
+ for lvl0, lvl1 \
411
+ in zip(lvl0_pos_y_norm.split(lvl0_token_lens), lvl1_pos_y_norm.split(lvl1_token_lens))]
412
+ pos_x_list = [torch.cat([lvl0, lvl1]) \
413
+ for lvl0, lvl1 \
414
+ in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]
415
+ lvl_list = [torch.cat([torch.zeros_like(lvl0, dtype=int), torch.ones_like(lvl1, dtype=int)]) \
416
+ for lvl0, lvl1 \
417
+ in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]
418
+
419
+
420
+ else:
421
+ # prune lvl1 correspond to lvl0
422
+ lvl1_valid_mask = self.pad_mask(lvl1_valid_mask)
423
+ lvl1_keep = lvl1_valid_mask & (~lvl1_sat_mask)
424
+ lvl1_to_lvl2 = ~lvl1_valid_mask
425
+
426
+ token_lvls = [lvl0_tokens, lvl1_tokens]
427
+ token_lens_lvls = [lvl0_token_lens, lvl1_token_lens]
428
+ pos_y_lvls = [lvl0_pos_y, lvl1_pos_y]
429
+ pos_x_lvls = [lvl0_pos_x, lvl1_pos_x]
430
+
431
+ to_next_lvl = lvl1_to_lvl2
432
+ keep = lvl1_keep
433
+ lvl_zorders = lvl1_zorders
434
+ lvl_bids = lvl1_bids
435
+ pad_vals = torch.full((3,), -1, dtype=lvl_zorders.dtype, device=lvl_zorders.device)
436
+ for lvl in range(self.sat_cfg['num_lvls']-2):
437
+ if to_next_lvl.sum() == 0:
438
+ break
439
+ next_tokens = self.lvl_pooling(token_lvls[-1][to_next_lvl])
440
+ # next_tokens = torch.max(token_lvls[-1][to_next_lvl].view(-1,4,C), dim=1)[0]
441
+ next_pos_y = pos_y_lvls[-1][to_next_lvl][::4]>>1
442
+ next_pos_x = pos_x_lvls[-1][to_next_lvl][::4]>>1
443
+ next_lens = [msk.sum().item()//4 for msk in to_next_lvl.split(token_lens_lvls[-1])]
444
+
445
+
446
+ token_lvls[-1] = token_lvls[-1][keep]
447
+ pos_y_lvls[-1] = pos_y_lvls[-1][keep]
448
+ pos_x_lvls[-1] = pos_x_lvls[-1][keep]
449
+ token_lens_lvls[-1] = [msk.sum().item() for msk in keep.split(token_lens_lvls[-1])]
450
+
451
+ token_lvls.append(next_tokens)
452
+ token_lens_lvls.append(next_lens)
453
+ pos_y_lvls.append(next_pos_y)
454
+ pos_x_lvls.append(next_pos_x)
455
+
456
+ if lvl < self.sat_cfg['num_lvls']-3:
457
+ lvl_zorders = lvl_zorders[to_next_lvl][::4]>>2
458
+ lvl_bids = lvl_bids[to_next_lvl][::4]
459
+
460
+ z_starts_idx = torch.where((lvl_zorders&3)==0)[0]
461
+ padded_z = torch.cat([lvl_zorders, pad_vals])
462
+ padded_bids = torch.cat([lvl_bids, pad_vals])
463
+ valids = (padded_z[z_starts_idx] + 3 == padded_z[z_starts_idx + 3]) & (padded_bids[z_starts_idx] == padded_bids[z_starts_idx + 3])
464
+ valid_starts = z_starts_idx[valids]
465
+
466
+ to_next_lvl = torch.zeros_like(lvl_zorders, dtype=bool)
467
+ to_next_lvl[valid_starts] = True
468
+ to_next_lvl[valid_starts+1] = True
469
+ to_next_lvl[valid_starts+2] = True
470
+ to_next_lvl[valid_starts+3] = True
471
+
472
+ keep = ~to_next_lvl
473
+
474
+ norm_pos_y_lvls = [(pos_y.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_y in enumerate(pos_y_lvls)]
475
+ norm_pos_x_lvls = [(pos_x.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_x in enumerate(pos_x_lvls)]
476
+
477
+ x_list = [torch.cat([cr, *lvls]).unsqueeze(0) \
478
+ for cr, *lvls \
479
+ in zip(cr_token_list, *[tokens.split(lens) for (tokens, lens) in zip(token_lvls, token_lens_lvls)])]
480
+ pos_y_list = [torch.cat([*lvls]) \
481
+ for lvls \
482
+ in zip(*[pos_y.split(lens) for (pos_y, lens) in zip(norm_pos_y_lvls, token_lens_lvls)])]
483
+ pos_x_list = [torch.cat([*lvls]) \
484
+ for lvls \
485
+ in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]
486
+ lvl_list = [torch.cat([torch.full_like(lvl, i, dtype=torch.int64) for i, lvl in enumerate(lvls)]) \
487
+ for lvls \
488
+ in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]
489
+
490
+
491
+
492
+ start = self.sat_cfg['get_map_layer'] if self.use_sat else 0
493
+ _, final_feature_list = self.encoder.forward_specific_layers_list(x_list, start = start, norm=True)
494
+
495
+ # proj
496
+ token_lens = [feature.shape[1] for feature in final_feature_list]
497
+ final_features = self.feature_proj(torch.cat(final_feature_list,dim=1).squeeze(0)) # (sum(L), C)
498
+ assert tuple(final_features.shape) == (sum(token_lens), self.hidden_dim)
499
+ # positional encoding
500
+ pos_embeds = position_encoding_xy(torch.cat(pos_x_list,dim=0), torch.cat(pos_y_list,dim=0), embedding_dim=self.hidden_dim)
501
+ if self.use_sat and self.sat_cfg['lvl_embed']:
502
+ lvl_embeds = self.level_embed[torch.cat(lvl_list,dim=0)]
503
+ pos_embeds = pos_embeds + lvl_embeds
504
+
505
+ sat_dict = {'pos_y': pos_y_list, 'pos_x': pos_x_list, 'lvl': lvl_list,
506
+ # 'features': [feature.squeeze(0) for feature in final_feature_list],
507
+ 'lens': token_lens}
508
+
509
+ return final_features, pos_embeds, token_lens, scale_map_dict, sat_dict
510
+
511
    def process_smpl(self, poses, shapes, cam_xys, cam_intrinsics, detach_j3ds = False):
        """Run the SMPL body model on per-query predictions and project to the image.

        Args:
            poses: (bs, num_queries, num_poses*3) axis-angle pose parameters.
            shapes: (bs, num_queries, 10) SMPL shape (beta) parameters.
            cam_xys: (bs, num_queries, 3) weak-perspective camera params
                (x, y, raw scale logit).
            cam_intrinsics: camera intrinsics used for projection; indexed with
                .transpose(2,3), so presumably (bs, 1, 3, 3) — TODO confirm against caller.
            detach_j3ds: if True, stop gradients from the 2D reprojection loss
                flowing into the 3D joints (only the translation stays differentiable).

        Returns:
            verts_cam: (bs, num_queries, V, 3) vertices in camera space
                (used for visualization/evaluation only).
            j3ds_cam: (bs, num_queries, J, 3) 3D joints in camera space.
            j2ds_img: (bs, num_queries, J, 2) projected 2D joints in image coords.
            depths: (bs, num_queries, 2) — root depth and depth normalized by focal.
            transl: (bs, num_queries, 3) recovered camera-space translation.
        """
        bs, num_queries, _ = poses.shape # should be (bs,n_q,num_poses*3)

        # flatten and compute
        poses = poses.flatten(0,1) # (bs*n_q,24*3)
        shapes = shapes.flatten(0,1) # (bs*n_q,10)
        verts, joints = self.human_model(poses=poses,
                                        betas=shapes)
        num_verts = verts.shape[1]
        num_joints = joints.shape[1]
        verts = verts.reshape(bs,num_queries,num_verts,3)
        joints = joints.reshape(bs,num_queries,num_joints,3)

        # apply cam_trans and projection:
        # convert the predicted (x, y, s) weak-perspective camera into a full
        # 3D translation under a fixed-focal perspective camera.
        scale = 2*cam_xys[:,:,2:].sigmoid() + 1e-6  # sigmoid -> (0,2); eps avoids div-by-zero
        t_xy = cam_xys[:,:,:2]/scale
        t_z = (2*self.focal)/(scale*self.input_size) # (bs,num_queries,1)
        transl = torch.cat([t_xy,t_z],dim=2)[:,:,None,:] # (bs,nq,1,3)

        verts_cam = verts + transl # only for visualization and evaluation
        j3ds_cam = joints + transl

        # perspective projection; optionally detach the articulated joints so the
        # 2D loss only updates the translation branch
        if detach_j3ds:
            j2ds_homo = torch.matmul(joints.detach() + transl, cam_intrinsics.transpose(2,3))
        else:
            j2ds_homo = torch.matmul(j3ds_cam, cam_intrinsics.transpose(2,3))
        # homogeneous divide; eps guards against zero depth
        j2ds_img = (j2ds_homo[..., :2] / (j2ds_homo[..., 2, None] + 1e-6)).reshape(bs,num_queries,num_joints,2)

        # depth of joint 0 (presumably the root/pelvis joint — verify ordering in human_model)
        depths = j3ds_cam[:,:,0,2:] # (bs, n_q, 1)
        depths = torch.cat([depths, depths/self.focal], dim=-1) # (bs, n_q, 2)

        return verts_cam, j3ds_cam, j2ds_img, depths, transl.flatten(2)
543
+
544
+
545
    def forward(self, samples: NestedTensor, targets, sat_use_gt = False, detach_j3ds = False):
        """Full model forward pass: encoder (with optional SAT multi-level tokens),
        DAB-DETR-style decoder, then per-layer SMPL/box/confidence heads.

        Args:
            samples: list of image tensors (3 x H x W) or a batched tensor;
                the annotation says NestedTensor but the assert below only
                accepts list/Tensor — NOTE(review): annotation looks stale.
            targets: per-image target dicts; 'img_size' is read here, and the
                full dicts are forwarded to the encoder / DN query preparation.
            sat_use_gt: if True, the encoder uses ground-truth scale maps
                instead of predicted ones (teacher-forcing for SAT).
            detach_j3ds: forwarded to process_smpl (detach 3D joints in the
                2D reprojection).

        Returns:
            dict with final-layer predictions ('pred_poses', 'pred_betas',
            'pred_boxes', 'pred_confs', 'pred_j3ds', 'pred_j2ds', 'pred_verts',
            'pred_intrinsics', 'pred_depths', 'pred_transl'), plus
            'aux_outputs' (training w/ aux loss), 'enc_outputs'/'sat' (SAT),
            and 'dn_meta' (training w/ denoising).
        """

        assert isinstance(samples, (list, torch.Tensor))

        # cache the interpolated lvl-1 positional embedding at eval time only;
        # during training it must be recomputed (pos embeds may be trainable)
        if self.training:
            self.preprocessed_pos_lvl1 = None

        elif self.preprocessed_pos_lvl1 is None and self.use_sat:
            self.preprocessed_pos_lvl1 = F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
                                            mode="bicubic",
                                            antialias=self.encoder.interpolate_antialias,
                                            size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)


        bs = len(targets)

        # get cam_intrinsics: scale the principal point by each image's
        # valid (unpadded) ratio relative to the square input size
        img_size = torch.stack([t['img_size'].flip(0) for t in targets])  # (bs, 2) as (w, h)
        valid_ratio = img_size/self.input_size

        cam_intrinsics = self.cam_intrinsics.repeat(bs, 1, 1, 1)
        cam_intrinsics[...,:2,2] = cam_intrinsics[...,:2,2] * valid_ratio[:, None, :]


        final_features, pos_embeds, token_lens, scale_map_dict, sat_dict\
              = self.forward_encoder(samples, targets, use_gt = sat_use_gt)

        # default dab-detr pipeline: learned content queries + box anchors
        embedweight = (self.refpoint_embed.weight).unsqueeze(0).repeat(bs,1,1)
        tgt = (self.tgt_embed.weight).unsqueeze(0).repeat(bs,1,1)

        # contrastive denoising: prepend noised GT queries during training
        if self.training and self.use_dn:
            input_query_tgt, input_query_bbox, attn_mask, dn_meta =\
                prepare_for_cdn(targets = targets, dn_cfg = self.dn_cfg,
                                num_queries = self.num_queries, hidden_dim = self.hidden_dim, dn_enc = self.dn_enc)
            tgt = torch.cat([input_query_tgt, tgt], dim=1)
            embedweight = torch.cat([input_query_bbox, embedweight], dim=1)
        else:
            attn_mask = None

        tgt_lens = [tgt.shape[1]]*bs

        hs, reference = self.decoder(memory=final_features, memory_lens=token_lens,
                                     tgt=tgt.flatten(0,1), tgt_lens=tgt_lens,
                                     refpoint_embed=embedweight.flatten(0,1),
                                     pos_embed=pos_embeds,
                                     self_attn_mask = attn_mask)

        # iterative box refinement: each layer predicts a delta in logit space
        reference_before_sigmoid = inverse_sigmoid(reference)
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            tmp = self.bbox_embed[lvl](hs[lvl])
            tmp[..., :self.query_dim] += reference_before_sigmoid[lvl]
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)
        pred_boxes = torch.stack(outputs_coords)

        outputs_poses = []
        outputs_shapes = []
        outputs_confs = []
        outputs_j3ds = []
        outputs_j2ds = []
        outputs_depths = []

        # shape of hs: (lvl, bs, num_queries, dim)
        # pose/shape are predicted as cumulative residuals on top of the mean SMPL params
        outputs_pose_6d = self.mean_pose.view(1, 1, -1)
        outputs_shape = self.mean_shape.view(1, 1, -1)
        for lvl in range(hs.shape[0]):

            outputs_pose_6d = outputs_pose_6d + self.pose_head[lvl](hs[lvl])
            outputs_shape = outputs_shape + self.shape_head[lvl](hs[lvl])

            # at eval time only the last layer is decoded through SMPL (it is expensive);
            # during training every layer contributes to the aux losses
            if self.training or lvl == hs.shape[0] - 1:
                outputs_pose = rot6d_to_axis_angle(outputs_pose_6d)

                outputs_conf = self.conf_head(hs[lvl]).sigmoid()

                # cam
                cam_xys = self.cam_head(hs[lvl])

                outputs_vert, outputs_j3d, outputs_j2d, depth, transl\
                      = self.process_smpl(poses = outputs_pose,
                                          shapes = outputs_shape,
                                          cam_xys = cam_xys,
                                          cam_intrinsics = cam_intrinsics,
                                          detach_j3ds = detach_j3ds)

                outputs_poses.append(outputs_pose)
                outputs_shapes.append(outputs_shape)
                outputs_confs.append(outputs_conf)
                # outputs_verts.append(outputs_vert)
                outputs_j3ds.append(outputs_j3d)
                outputs_j2ds.append(outputs_j2d)
                outputs_depths.append(depth)

        pred_poses = torch.stack(outputs_poses)
        pred_betas = torch.stack(outputs_shapes)
        pred_confs = torch.stack(outputs_confs)
        # verts/transl are only kept for the last decoded layer (loop leftovers)
        pred_verts = outputs_vert
        pred_transl = transl
        pred_intrinsics = cam_intrinsics
        pred_j3ds = torch.stack(outputs_j3ds)
        pred_j2ds = torch.stack(outputs_j2ds)
        pred_depths = torch.stack(outputs_depths)



        # split the denoising queries back out of the prediction tensors
        # NOTE(review): `self.training > 0` is redundant — self.training is a bool
        if self.training > 0 and self.use_dn:
            pred_poses, pred_betas,\
            pred_boxes, pred_confs,\
            pred_j3ds, pred_j2ds, pred_depths,\
            pred_verts, pred_transl =\
                dn_post_process(pred_poses, pred_betas,
                                pred_boxes, pred_confs,
                                pred_j3ds, pred_j2ds, pred_depths,
                                pred_verts, pred_transl,
                                dn_meta, self.aux_loss, self._set_aux_loss)


        # final-layer ([-1]) predictions form the main output
        out = {'pred_poses': pred_poses[-1], 'pred_betas': pred_betas[-1],
               'pred_boxes': pred_boxes[-1], 'pred_confs': pred_confs[-1],
               'pred_j3ds': pred_j3ds[-1], 'pred_j2ds': pred_j2ds[-1],
               'pred_verts': pred_verts, 'pred_intrinsics': pred_intrinsics,
               'pred_depths': pred_depths[-1], 'pred_transl': pred_transl}

        if self.aux_loss and self.training:
            out['aux_outputs'] = self._set_aux_loss(pred_poses, pred_betas,
                                                    pred_boxes, pred_confs,
                                                    pred_j3ds, pred_j2ds, pred_depths)

        if self.use_sat:
            out['enc_outputs'] = scale_map_dict

            out['sat'] = sat_dict

        if self.training > 0 and self.use_dn:
            out['dn_meta'] = dn_meta

        return out
698
+
699
+ @torch.jit.unused
700
+ def _set_aux_loss(self, pred_poses, pred_betas, pred_boxes,
701
+ pred_confs, pred_j3ds,
702
+ pred_j2ds, pred_depths):
703
+ # this is a workaround to make torchscript happy, as torchscript
704
+ # doesn't support dictionary with non-homogeneous values, such
705
+ # as a dict having both a Tensor and a list.
706
+ return [{'pred_poses': a, 'pred_betas': b,
707
+ 'pred_boxes': c, 'pred_confs': d,
708
+ 'pred_j3ds': e, 'pred_j2ds': f, 'pred_depths': g}
709
+ for a, b, c, d, e, f, g in zip(pred_poses[:-1], pred_betas[:-1],
710
+ pred_boxes[:-1], pred_confs[:-1], pred_j3ds[:-1], pred_j2ds[:-1], pred_depths[:-1])]
711
+
712
+
713
+
714
class MLP(nn.Module):
    """A plain feed-forward network: stacked Linear layers with ReLU in between.

    No activation is applied after the final layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # layer widths: input -> hidden * (num_layers-1) -> output
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:  # ReLU on every layer except the output one
                x = F.relu(x)
        return x
727
+
728
+
729
def build_sat_model(args, set_criterion=True):
    """Assemble the SAT model and, optionally, its training criterion.

    When ``set_criterion`` is True, ``args.weight_dict`` is extended IN PLACE
    with denoising ('<k>_dn'), auxiliary-layer ('<k>.<i>') and SAT
    ('map_confs') loss weights before the SetCriterion is built.

    Returns:
        (model, criterion) — criterion is None when ``set_criterion`` is False.
    """
    model = Model(
        build_encoder(args),
        build_decoder(args),
        num_queries=args.num_queries,
        input_size=args.input_size,
        sat_cfg=args.sat_cfg,
        dn_cfg=args.dn_cfg,
        train_pos_embed=getattr(args, 'train_pos_embed', True),
    )

    if not set_criterion:
        return model, None

    matcher = build_matcher(args)
    weight_dict = args.weight_dict  # NOTE: mutated in place below
    losses = args.losses

    # denoising losses reuse the base weights under a '_dn' suffix
    if args.dn_cfg['use_dn']:
        weight_dict.update({f'{k}_dn': v for k, v in weight_dict.items()})

    # replicate every weight for each auxiliary decoder layer as '<k>.<i>'
    aux_weights = {}
    for layer_idx in range(args.dec_layers - 1):
        for k, v in weight_dict.items():
            aux_weights[f'{k}.{layer_idx}'] = v
    weight_dict.update(aux_weights)

    # SAT scale-map confidence shares the weight of the query confidences
    if args.sat_cfg['use_sat'] and 'map_confs' not in weight_dict:
        weight_dict['map_confs'] = weight_dict['confs']
        # weight_dict.update({'map_scales': })

    criterion = SetCriterion(matcher, weight_dict, losses=losses,
                             j2ds_norm_scale=args.input_size)
    return model, criterion
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.26.1
2
+ chumpy
3
+ smplx
4
+ opencv-python
5
+ trimesh
6
+ tensorboard
7
+ scipy
8
+ pyrender==0.1.45
9
+ joblib
10
+ termcolor
11
+ transformers
12
+ matplotlib
13
+ scikit-learn
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
utils/box_ops.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch, os
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
14
+
15
+
16
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
    x0, y0, x1, y1 = x.unbind(-1)
    center_form = [(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0]
    return torch.stack(center_form, dim=-1)
21
+
22
+
23
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between boxes1 [N,4] and boxes2 [M,4] in xyxy format.

    Returns:
        iou: [N, M] intersection-over-union (eps-stabilized denominator).
        union: [N, M] union areas.
    """
    # per-box areas (equivalent to torchvision.ops.boxes.box_area)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])    # [N,M,2]
    bot_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])   # [N,M,2]

    inter_wh = (bot_right - top_left).clamp(min=0)              # [N,M,2]
    inter = inter_wh[..., 0] * inter_wh[..., 1]                 # [N,M]

    union = area1[:, None] + area2 - inter
    return inter / (union + 1e-6), union
39
+
40
+
41
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    Boxes must be in [x0, y0, x1, y1] format.

    Returns an [N, M] pairwise matrix, where N = len(boxes1) and
    M = len(boxes2).
    """
    # degenerate (inverted) boxes give inf / nan results, so check early
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # smallest enclosing (hull) box for every pair
    hull_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = (hull_br - hull_tl).clamp(min=0)   # [N,M,2]
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]

    return iou - (hull_area - union) / (hull_area + 1e-6)
65
+
66
+
67
+
68
# modified from torchvision to also return the union
def box_iou_pairwise(boxes1, boxes2):
    """Element-wise IoU of two aligned box lists (both [N,4], xyxy format).

    Returns:
        iou: [N] — note: no eps here, zero-union pairs produce nan/inf.
        union: [N] union areas.
    """
    # per-box areas (equivalent to torchvision.ops.boxes.box_area)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    overlap_tl = torch.max(boxes1[:, :2], boxes2[:, :2])   # [N,2]
    overlap_br = torch.min(boxes1[:, 2:], boxes2[:, 2:])   # [N,2]

    wh = (overlap_br - overlap_tl).clamp(min=0)            # [N,2]
    inter = wh[:, 0] * wh[:, 1]                            # [N]

    union = area1 + area2 - inter
    return inter / union, union
83
+
84
+
85
def generalized_box_iou_pairwise(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/ (element-wise variant).

    Input:
        - boxes1, boxes2: aligned [N, 4] tensors in [x0, y0, x1, y1] format.
    Output:
        - giou: [N]
    """
    # degenerate (inverted) boxes give inf / nan results, so check early
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    assert boxes1.shape == boxes2.shape

    iou, union = box_iou_pairwise(boxes1, boxes2)

    # smallest enclosing (hull) box per pair
    hull_tl = torch.min(boxes1[:, :2], boxes2[:, :2])
    hull_br = torch.max(boxes1[:, 2:], boxes2[:, 2:])
    hull_wh = (hull_br - hull_tl).clamp(min=0)   # [N,2]
    hull_area = hull_wh[:, 0] * hull_wh[:, 1]

    return iou - (hull_area - union) / hull_area
108
+
109
def masks_to_boxes(masks):
    """Compute the tight bounding box around each provided mask.

    Args:
        masks: [N, H, W] tensor; nonzero pixels belong to the mask.

    Returns:
        [N, 4] tensor of boxes in (x_min, y_min, x_max, y_max) format.
        An all-zero mask yields the degenerate box (0, 0, 0, 0).
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Fix: build the coordinate grids on the masks' device — the original
    # created them on CPU, which fails for CUDA masks.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    # Fix: pass indexing explicitly; 'ij' matches the old default and avoids
    # the torch.meshgrid deprecation warning.
    y, x = torch.meshgrid(y, x, indexing="ij")

    # max over coordinates weighted by the mask gives the far edge;
    # for the near edge, fill non-mask pixels with a huge value and take min
    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
134
+
135
if __name__ == '__main__':
    # Ad-hoc manual smoke test: compute pairwise IoU for random boxes and
    # drop into an interactive ipdb shell for inspection.
    # NOTE(review): debug-only harness — requires the third-party `ipdb`
    # package and random xyxy boxes here are not guaranteed to be valid
    # (x0<=x1), which box_iou tolerates but generalized_box_iou would not.
    x = torch.rand(5, 4)
    y = torch.rand(3, 4)
    iou, union = box_iou(x, y)
    import ipdb; ipdb.set_trace()