lhx05 commited on
Commit
fb24bef
·
verified ·
1 Parent(s): a2f2478

Upload CVLFace experiment code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. cvlface/research/recognition/code/run_v1/README.md +0 -0
  2. cvlface/research/recognition/code/run_v1/aligners/__init__.py +25 -0
  3. cvlface/research/recognition/code/run_v1/aligners/base/__init__.py +60 -0
  4. cvlface/research/recognition/code/run_v1/aligners/base/utils.py +91 -0
  5. cvlface/research/recognition/code/run_v1/aligners/configs/dfa.yaml +10 -0
  6. cvlface/research/recognition/code/run_v1/aligners/configs/none.yaml +3 -0
  7. cvlface/research/recognition/code/run_v1/aligners/configs/retinaface.yaml +3 -0
  8. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/__init__.py +117 -0
  9. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/aligner_helper.py +97 -0
  10. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/__init__.py +27 -0
  11. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/config.py +18 -0
  12. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/__init__.py +2 -0
  13. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/functions/prior_box.py +140 -0
  14. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/__init__.py +3 -0
  15. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/multibox_loss.py +144 -0
  16. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/__init__.py +0 -0
  17. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/net.py +132 -0
  18. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/retinaface.py +142 -0
  19. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/preprocessor.py +93 -0
  20. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/box_utils.py +239 -0
  21. cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/model_utils.py +36 -0
  22. cvlface/research/recognition/code/run_v1/aligners/none/__init__.py +20 -0
  23. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py +246 -0
  24. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/aligner_helper.py +97 -0
  25. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/__init__.py +28 -0
  26. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/config.py +18 -0
  27. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/__init__.py +2 -0
  28. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/functions/prior_box.py +140 -0
  29. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/__init__.py +3 -0
  30. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/multibox_loss.py +144 -0
  31. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/__init__.py +0 -0
  32. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/net.py +132 -0
  33. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/retinaface.py +123 -0
  34. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py +93 -0
  35. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/box_utils.py +239 -0
  36. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/model_utils.py +36 -0
  37. cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface_pipeline.py +247 -0
  38. cvlface/research/recognition/code/run_v1/base.yaml +12 -0
  39. cvlface/research/recognition/code/run_v1/classifiers/__init__.py +31 -0
  40. cvlface/research/recognition/code/run_v1/classifiers/base/__init__.py +87 -0
  41. cvlface/research/recognition/code/run_v1/classifiers/base/utils.py +91 -0
  42. cvlface/research/recognition/code/run_v1/classifiers/configs/fc.yaml +4 -0
  43. cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc.yaml +4 -0
  44. cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_freeze.yaml +4 -0
  45. cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10.yaml +4 -0
  46. cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10_freeze.yaml +4 -0
  47. cvlface/research/recognition/code/run_v1/classifiers/fc/__init__.py +55 -0
  48. cvlface/research/recognition/code/run_v1/classifiers/fc/fc.py +67 -0
  49. cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py +39 -0
  50. cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py +289 -0
cvlface/research/recognition/code/run_v1/README.md ADDED
File without changes
cvlface/research/recognition/code/run_v1/aligners/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import BaseAligner
2
+
3
+
4
+ def get_aligner(aligner_cfg):
5
+
6
+ if aligner_cfg.name == 'none':
7
+ from .none import NoneAligner
8
+ aligner = NoneAligner.from_config(aligner_cfg)
9
+ elif aligner_cfg.name == 'retinaface_aligner':
10
+ from .retinaface_aligner import RetinaFaceAligner
11
+ aligner = RetinaFaceAligner.from_config(aligner_cfg)
12
+ elif aligner_cfg.name == 'differentiable_face_aligner':
13
+ from .differentiable_face_aligner import DifferentiableFaceAligner
14
+ aligner = DifferentiableFaceAligner.from_config(aligner_cfg)
15
+ else:
16
+ raise ValueError(f"Unknown classifier: {aligner_cfg.name}")
17
+
18
+ if aligner_cfg.start_from:
19
+ aligner.load_state_dict_from_path(aligner_cfg.start_from)
20
+
21
+ if aligner_cfg.freeze:
22
+ for param in aligner.parameters():
23
+ param.requires_grad = False
24
+ return aligner
25
+
cvlface/research/recognition/code/run_v1/aligners/base/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+ import torch
4
+ from torch import device
5
+ from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
6
+
7
+ class BaseAligner(torch.nn.Module):
8
+
9
+ def __init__(self, config=None):
10
+ super().__init__()
11
+ self.config = config
12
+
13
+ @classmethod
14
+ def from_config(cls, config) -> "BaseAligner":
15
+ raise NotImplementedError('from_config must be implemented in subclass')
16
+
17
+ def make_train_transform(self):
18
+ raise NotImplementedError('from_config must be implemented in subclass')
19
+
20
+ def make_test_transform(self):
21
+ raise NotImplementedError('from_config must be implemented in subclass')
22
+
23
+ def forward(self, x):
24
+ raise NotImplementedError('from_config must be implemented in subclass')
25
+
26
+ def save_pretrained(
27
+ self,
28
+ save_dir: Union[str, os.PathLike],
29
+ name: str = 'model.pt',
30
+ rank: int = 0,
31
+ ):
32
+ save_path = os.path.join(save_dir, name)
33
+ if rank == 0:
34
+ save_state_dict_and_config(self.state_dict(), self.config, save_path)
35
+
36
+ def load_state_dict_from_path(self, pretrained_model_path):
37
+ state_dict = load_state_dict_from_path(pretrained_model_path)
38
+ result = self.load_state_dict(state_dict)
39
+ print(f"Loaded pretrained aligner from {pretrained_model_path}")
40
+
41
+
42
+ @property
43
+ def device(self) -> device:
44
+ return get_parameter_device(self)
45
+
46
+ @property
47
+ def dtype(self) -> torch.dtype:
48
+ return get_parameter_dtype(self)
49
+
50
+ def num_parameters(self, only_trainable: bool = False) -> int:
51
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
52
+
53
+ def has_trainable_params(self):
54
+ for param in self.parameters():
55
+ if param.requires_grad:
56
+ return True
57
+ return False
58
+
59
+ def has_params(self):
60
+ return len(list(self.parameters())) > 0
cvlface/research/recognition/code/run_v1/aligners/base/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import List, Optional, Tuple, Union
3
+ import safetensors
4
+ import torch
5
+ from torch import Tensor
6
+ import os
7
+ from pathlib import Path
8
+ from omegaconf import DictConfig, OmegaConf
9
+
10
+
11
+ def get_parameter_device(parameter: torch.nn.Module):
12
+ try:
13
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
14
+ return next(parameters_and_buffers).device
15
+ except StopIteration:
16
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
17
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
18
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
19
+ return tuples
20
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
21
+ first_tuple = next(gen)
22
+ return first_tuple[1].device
23
+
24
+
25
+ def get_parameter_dtype(parameter: torch.nn.Module):
26
+ try:
27
+ params = tuple(parameter.parameters())
28
+ if len(params) > 0:
29
+ return params[0].dtype
30
+
31
+ buffers = tuple(parameter.buffers())
32
+ if len(buffers) > 0:
33
+ return buffers[0].dtype
34
+
35
+ except StopIteration:
36
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
37
+
38
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
39
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
40
+ return tuples
41
+
42
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
43
+ first_tuple = next(gen)
44
+ return first_tuple[1].dtype
45
+
46
+
47
+ def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
48
+ path_obj = Path(save_path)
49
+ return path_obj.parent
50
+
51
+ def get_base_name(save_path: Union[str, os.PathLike]) -> str:
52
+ path_obj = Path(save_path)
53
+ return path_obj.name
54
+
55
+ def load_state_dict_from_path(path: Union[str, os.PathLike]):
56
+ # Load a state dict from a path.
57
+ if 'safetensors' in path:
58
+ state_dict = safetensors.torch.load_file(path)
59
+ else:
60
+ state_dict = torch.load(path, map_location="cpu")
61
+ return state_dict
62
+
63
+ def replace_extension(path, new_extension):
64
+ if not new_extension.startswith('.'):
65
+ new_extension = '.' + new_extension
66
+ return os.path.splitext(path)[0] + new_extension
67
+
68
+ def make_config_path(save_path):
69
+ config_path = replace_extension(save_path, '.yaml')
70
+ return config_path
71
+
72
+ def save_config(config, config_path):
73
+ assert isinstance(config, dict) or isinstance(config, DictConfig)
74
+ os.makedirs(get_parent_directory(config_path), exist_ok=True)
75
+ if isinstance(config, dict):
76
+ config = OmegaConf.create(config)
77
+ OmegaConf.save(config, config_path)
78
+
79
+
80
+ def save_state_dict_and_config(state_dict, config, save_path):
81
+ os.makedirs(get_parent_directory(save_path), exist_ok=True)
82
+
83
+ # save config dict
84
+ config_path = make_config_path(save_path)
85
+ save_config(config, config_path)
86
+
87
+ # Save the model
88
+ if 'safetensors' in save_path:
89
+ safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
90
+ else:
91
+ torch.save(state_dict, save_path)
cvlface/research/recognition/code/run_v1/aligners/configs/dfa.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ name: differentiable_face_aligner
2
+ arch: 'mobile0.25'
3
+ start_from: '../../../../pretrained_models/alignment/dfa_mobilenet/mobilenet0.25.pth'
4
+ freeze: True
5
+
6
+ input_padding_ratio: 0 # pad the input to this size before resize
7
+ input_padding_val: 'zero'
8
+ input_size: 160 # resize the input to this size
9
+ output_size: 112 # size of the output of aligner
10
+ color_space: 'RGB' # color space of the input image
cvlface/research/recognition/code/run_v1/aligners/configs/none.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: none
2
+ start_from: ''
3
+ freeze: False
cvlface/research/recognition/code/run_v1/aligners/configs/retinaface.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: retinaface
2
+ start_from: ''
3
+ freeze: True
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseAligner
2
+ from torchvision import transforms
3
+ from .dfa import get_landmark_predictor, get_preprocessor
4
+ from . import aligner_helper
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+
10
+ class DifferentiableFaceAligner(BaseAligner):
11
+
12
+ '''
13
+ A differentiable face aligner that aligns the image with one face to a canonical position.
14
+ The aligner is based on the following paper (check out supplementary material for more details):
15
+ @inproceedings{kim2024kprpe,
16
+ title={{KeyPoint Relative Position Encoding for Face Recognition},
17
+ author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
18
+ booktitle={CVPR},
19
+ year={2024}
20
+ }
21
+ '''
22
+
23
+ def __init__(self, net, prior_box, preprocessor, config):
24
+ super(DifferentiableFaceAligner, self).__init__()
25
+ self.net = net
26
+ self.prior_box = prior_box
27
+ self.preprocessor = preprocessor
28
+ self.config = config
29
+
30
+ @classmethod
31
+ def from_config(cls, config):
32
+ net, prior_box = get_landmark_predictor(network=config.arch,
33
+ use_aggregator=True,
34
+ input_size=config.input_size)
35
+
36
+ preprocessor = get_preprocessor(output_size=config.input_size,
37
+ padding=config.input_padding_ratio,
38
+ padding_val=config.input_padding_val)
39
+ if config.freeze:
40
+ for param in net.parameters():
41
+ param.requires_grad = False
42
+ model = cls(net, prior_box, preprocessor, config)
43
+ model.eval()
44
+ return model
45
+
46
+ def forward(self, x, padding_ratio_override=None):
47
+
48
+ # input size check
49
+ assert x.shape[1] == 3
50
+ assert x.ndim == 4
51
+ assert isinstance(x, torch.Tensor)
52
+ is_square = x.shape[2] == x.shape[3]
53
+
54
+ x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
55
+ assert self.prior_box.image_size == x.shape[2:]
56
+
57
+ # make image into BGR
58
+ x_bgr = x.flip(1)
59
+ result = self.net(x_bgr, self.prior_box)
60
+ orig_pred_ldmks, bbox, cls = aligner_helper.split_network_output(result)
61
+ score = torch.nn.Softmax(dim=-1)(cls)[:,1:]
62
+
63
+ reference_ldmk = aligner_helper.reference_landmark()
64
+ input_size = self.config.input_size
65
+ output_size = self.config.output_size
66
+ cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
67
+ thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
68
+ thetas = thetas.to(orig_pred_ldmks.device)
69
+
70
+ output_size = torch.Size((len(thetas), 3, output_size, output_size))
71
+ grid = F.affine_grid(thetas, output_size, align_corners=True)
72
+ aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1 # +1, -1 for making padding pixel 0
73
+ aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)
74
+
75
+ orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
76
+ # bbox (xmin, ymin, xmax, ymax)
77
+ normalized_bbox = bbox / torch.tensor([[x_bgr.size(3), x_bgr.size(2)] * 2]).to(bbox.device)
78
+
79
+
80
+ if padding_ratio_override is None:
81
+ padding_ratio = self.preprocessor.padding
82
+ else:
83
+ padding_ratio = padding_ratio_override
84
+ if padding_ratio > 0:
85
+ # unpad the landmark so that it is in the original image coordinate
86
+ scale = 1 / (1 + (2 * padding_ratio))
87
+ pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
88
+ pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
89
+ unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
90
+ torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
91
+ unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
92
+ unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach()
93
+ unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2)
94
+ if not is_square:
95
+ unpad_ldmk_pred = None # cannot use this if the input is not square becaouse preprocessor changes input
96
+ normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input
97
+ return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox
98
+
99
+ if not is_square:
100
+ orig_pred_ldmks = None # cannot use this if the input is not square becaouse preprocessor changes input
101
+ normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input
102
+ return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox
103
+
104
+ def make_train_transform(self):
105
+ transform = transforms.Compose([
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
108
+ ])
109
+ return transform
110
+
111
+ def make_test_transform(self):
112
+ transform = transforms.Compose([
113
+ transforms.ToTensor(),
114
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
115
+ ])
116
+ return transform
117
+
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/aligner_helper.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from skimage import transform as trans
5
+ import cv2
6
+
7
+
8
+ def split_network_output(align_out):
9
+ anchor_bbox_pred, anchor_cls_pred, anchor_ldmk_pred, merged, _ = align_out
10
+ bbox, cls, ldmk = torch.split(merged, [4, 2, 10], dim=1)
11
+ return ldmk, bbox, cls
12
+
13
+
14
+ def get_cv2_affine_from_landmark(ldmks, reference_ldmk, image_width, image_height, ):
15
+ assert ldmks.ndim == 2 # batchdim
16
+ assert ldmks.shape[1] == 10
17
+ assert isinstance(ldmks, torch.Tensor)
18
+
19
+ assert reference_ldmk.ndim == 2
20
+ assert reference_ldmk.shape[0] == 5
21
+ assert reference_ldmk.shape[1] == 2
22
+ assert isinstance(reference_ldmk, np.ndarray)
23
+
24
+ to_img_size = np.array([[[image_width, image_height]]])
25
+ ldmks = ldmks.view(ldmks.shape[0], 5, 2).detach().cpu().numpy()
26
+ ldmks = ldmks * to_img_size
27
+ transforms = []
28
+ for ldmk in ldmks:
29
+ tform = trans.SimilarityTransform()
30
+ tform.estimate(ldmk, reference_ldmk)
31
+ M = tform.params[0:2, :]
32
+ transforms.append(M)
33
+ transforms = np.stack(transforms, axis=0)
34
+ return transforms
35
+
36
+
37
+ def cv2_param_to_torch_theta(cv2_tfms, image_width, image_height, output_width, output_height):
38
+ # https://github.com/wuneng/WarpAffine2GridSample
39
+ """4.Affine Transformation Matrix to theta"""
40
+ assert cv2_tfms.ndim == 3 # N, 2, 3
41
+ assert cv2_tfms.shape[1] == 2
42
+ assert cv2_tfms.shape[2] == 3
43
+
44
+ srcs = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.float32)
45
+ srcs = np.expand_dims(srcs, axis=0).repeat(cv2_tfms.shape[0], axis=0)
46
+ dsts = np.matmul(srcs, cv2_tfms[:, :, :2].transpose(0, 2, 1)) + cv2_tfms[:, :, 2:3].transpose(0, 2, 1)
47
+
48
+ # normalize to [-1, 1]
49
+ srcs = srcs / np.array([[[image_width, image_height]]]) * 2 - 1
50
+ dsts = dsts / np.array([[[output_width, output_height]]]) * 2 - 1
51
+
52
+ thetas = []
53
+ for src, dst in zip(srcs, dsts):
54
+ theta = trans.estimate_transform("affine", src=dst, dst=src).params[:2]
55
+ thetas.append(theta)
56
+ thetas = np.stack(thetas, axis=0)
57
+ thetas = torch.from_numpy(thetas).float()
58
+ return thetas
59
+
60
+
61
+ def adjust_ldmks(ldmks, thetas):
62
+ inv_thetas = inv_matrix(thetas).to(ldmks.device).float()
63
+ _ldmks = torch.cat([ldmks, torch.ones((ldmks.shape[0], 5, 1)).to(ldmks.device)], dim=2)
64
+ ldmk_aligned = (((_ldmks) * 2 - 1) @ inv_thetas.permute(0,2,1)) / 2 + 0.5
65
+ return ldmk_aligned
66
+
67
+
68
+ def inv_matrix(theta):
69
+ # torch batched version
70
+ assert theta.ndim == 3
71
+ a, b, t1 = theta[:, 0,0], theta[:, 0,1], theta[:, 0,2]
72
+ c, d, t2 = theta[:, 1,0], theta[:, 1,1], theta[:, 1,2]
73
+ det = a * d - b * c
74
+ inv_det = 1.0 / det
75
+ inv_mat = torch.stack([
76
+ torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1),
77
+ torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
78
+ ], dim=1)
79
+ return inv_mat
80
+
81
+ def reference_landmark():
82
+ return np.array([[38.29459953, 51.69630051],
83
+ [73.53179932, 51.50139999],
84
+ [56.02519989, 71.73660278],
85
+ [41.54930115, 92.3655014],
86
+ [70.72990036, 92.20410156]])
87
+
88
+
89
+ def draw_ldmk(img, ldmk):
90
+ if ldmk is None:
91
+ return img
92
+ colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
93
+ img = img.copy()
94
+ for i in range(5):
95
+ color = colors[i]
96
+ cv2.circle(img, (int(ldmk[i*2] * img.shape[1]), int(ldmk[i*2+1] * img.shape[0])), 1, color, 4)
97
+ return img
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models.retinaface import RetinaFace
2
+ from .utils.model_utils import load_model
3
+ from .config import cfg_mnet, cfg_re50
4
+ from .layers.functions.prior_box import PriorBox
5
+ from .preprocessor import Preprocessor
6
+
7
+ def get_landmark_predictor(network='mobile0.25', use_aggregator=True, input_size=160):
8
+
9
+ cfg = None
10
+ if network == "mobile0.25":
11
+ cfg = cfg_mnet
12
+ elif network == "resnet50":
13
+ cfg = cfg_re50
14
+ net = RetinaFace(cfg=cfg, phase = 'test', use_aggregator=use_aggregator)
15
+ priorbox = PriorBox(image_size=(input_size, input_size),
16
+ min_sizes=[[64, 80], [96, 112], [128, 144]],
17
+ steps=[8, 16, 32],
18
+ clip=False,
19
+ variances=[0.1, 0.2],)
20
+
21
+ # aligner = Aligner(net, priorbox, input_size, output_size=output_size)
22
+ # return aligner
23
+ return net, priorbox
24
+
25
+
26
+ def get_preprocessor(output_size=160, padding=0.0, padding_val='zero'):
27
+ return Preprocessor(output_size=output_size, padding=padding, padding_val=padding_val)
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+
3
+ cfg_mnet = {
4
+ 'name': 'mobilenet0.25',
5
+ 'pretrain': True,
6
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
7
+ 'in_channel': 32,
8
+ 'out_channel': 64
9
+ }
10
+
11
+ cfg_re50 = {
12
+ 'name': 'Resnet50',
13
+ 'pretrain': True,
14
+ 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
15
+ 'in_channel': 256,
16
+ 'out_channel': 256
17
+ }
18
+
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import *
2
+ from .modules import *
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/functions/prior_box.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from itertools import product as product
3
+ from math import ceil
4
+
5
+
6
+ class PriorBox(object):
7
+
8
+ def __init__(self,
9
+ image_size,
10
+ min_sizes=[[64, 80], [96, 112], [128, 144]],
11
+ steps=[8,16,32],
12
+ clip=False,
13
+ variances=[0.1, 0.2],
14
+ ):
15
+ super(PriorBox, self).__init__()
16
+ self.min_sizes = min_sizes
17
+ self.steps = steps
18
+ self.clip = clip
19
+ self.variances = variances
20
+ self.image_size = image_size
21
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
22
+ with torch.no_grad():
23
+ self.priors = self.forward()
24
+
25
+ def forward(self):
26
+ anchors = []
27
+ for k, f in enumerate(self.feature_maps):
28
+ min_sizes = self.min_sizes[k]
29
+ for i, j in product(range(f[0]), range(f[1])):
30
+ for min_size in min_sizes:
31
+ s_kx = min_size / self.image_size[1]
32
+ s_ky = min_size / self.image_size[0]
33
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
34
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
35
+ for cy, cx in product(dense_cy, dense_cx):
36
+ anchors += [cx, cy, s_kx, s_ky]
37
+
38
+ # back to torch land
39
+ output = torch.Tensor(anchors).view(-1, 4)
40
+ # import pandas as pd
41
+ # pd.DataFrame(output.numpy()).to_csv('/mckim/temp/temp.csv')
42
+ if self.clip:
43
+ output.clamp_(max=1, min=0)
44
+ return output
45
+
46
+ def encode(self, matched):
47
+ """Encode the variances from the priorbox layers into the ground truth boxes
48
+ we have matched (based on jaccard overlap) with the prior boxes.
49
+ """
50
+ self.priors = self.priors.to(matched.device)
51
+
52
+ # dist b/t match center and prior's center
53
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - self.priors[:, :2]
54
+ # encode variance
55
+ g_cxcy /= (self.variances[0] * self.priors[:, 2:])
56
+ # match wh / prior wh
57
+ g_wh = (matched[:, 2:] - matched[:, :2]) / self.priors[:, 2:]
58
+ g_wh = torch.log(g_wh) / self.variances[1]
59
+ # return target for smooth_l1_loss
60
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
61
+
62
+ def encode_landm(self, matched):
63
+ """Encode the variances from the priorbox layers into the ground truth boxes
64
+ we have matched (based on jaccard overlap) with the prior boxes.
65
+ """
66
+ self.priors = self.priors.to(matched.device)
67
+
68
+ # dist b/t match center and prior's center
69
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
70
+ priors_cx = self.priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
71
+ priors_cy = self.priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
72
+ priors_w = self.priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
73
+ priors_h = self.priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
74
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
75
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
76
+ # encode variance
77
+ g_cxcy /= (self.variances[0] * priors[:, :, 2:])
78
+ # g_cxcy /= priors[:, :, 2:]
79
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
80
+ # return target for smooth_l1_loss
81
+ return g_cxcy
82
+
83
+
84
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
85
+ def decode(self, loc):
86
+ """Decode locations from predictions using priors to undo
87
+ the encoding we did for offset regression at train time.
88
+ """
89
+ self.priors = self.priors.to(loc.device)
90
+
91
+ boxes = torch.cat((
92
+ self.priors[:, :2] + loc[:, :2] * self.variances[0] * self.priors[:, 2:],
93
+ self.priors[:, 2:] * torch.exp(loc[:, 2:] * self.variances[1])), 1)
94
+ boxes[:, :2] -= boxes[:, 2:] / 2
95
+ boxes[:, 2:] += boxes[:, :2]
96
+ return boxes
97
+
98
+ def decode_landm(self, pre):
99
+ """Decode landm from predictions using priors to undo
100
+ the encoding we did for offset regression at train time.
101
+ """
102
+ self.priors = self.priors.to(pre.device)
103
+ landms = torch.cat((self.priors[:, :2] + pre[:, :2] * self.variances[0] * self.priors[:, 2:],
104
+ self.priors[:, :2] + pre[:, 2:4] * self.variances[0] * self.priors[:, 2:],
105
+ self.priors[:, :2] + pre[:, 4:6] * self.variances[0] * self.priors[:, 2:],
106
+ self.priors[:, :2] + pre[:, 6:8] * self.variances[0] * self.priors[:, 2:],
107
+ self.priors[:, :2] + pre[:, 8:10] * self.variances[0] * self.priors[:, 2:],
108
+ ), dim=1)
109
+ return landms
110
+
111
+
112
+ def decode_batch(self, loc):
113
+ """Decode locations from predictions using priors to undo
114
+ the encoding we did for offset regression at train time.
115
+ """
116
+ self.priors = self.priors.to(loc.device)
117
+ assert loc.ndim == 3
118
+ priors = self.priors.unsqueeze(0).expand(loc.size(0), -1, -1)
119
+ boxes = torch.cat((
120
+ priors[:, :, :2] + loc[:, :, :2] * self.variances[0] * priors[:, :, 2:],
121
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * self.variances[1])), -1)
122
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
123
+ boxes[:, :, 2:] += boxes[:, :, :2]
124
+ return boxes
125
+
126
+
127
+ def decode_landm_batch(self, prediction):
128
+ """Decode landm from prediction using priors to undo
129
+ the encoding we did for offset regression at train time.
130
+ """
131
+ assert prediction.ndim == 3
132
+ self.priors = self.priors.to(prediction.device)
133
+ priors = self.priors.unsqueeze(0).expand(prediction.size(0), -1, -1)
134
+ landms = torch.cat((priors[:, :, :2] + prediction[:, :, :2] * self.variances[0] * priors[:, :, 2:],
135
+ priors[:, :, :2] + prediction[:, :, 2:4] * self.variances[0] * priors[:, :, 2:],
136
+ priors[:, :, :2] + prediction[:, :, 4:6] * self.variances[0] * priors[:, :, 2:],
137
+ priors[:, :, :2] + prediction[:, :, 6:8] * self.variances[0] * priors[:, :, 2:],
138
+ priors[:, :, :2] + prediction[:, :, 8:10] * self.variances[0] * priors[:, :, 2:],
139
+ ), dim=-1)
140
+ return landms
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .multibox_loss import MultiBoxLoss
2
+
3
+ __all__ = ['MultiBoxLoss']
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/multibox_loss.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ...utils.box_utils import match, log_sum_exp
5
+
6
+
7
+ class MultiBoxLoss(nn.Module):
8
+ """SSD Weighted Loss Function
9
+ Compute Targets:
10
+ 1) Produce Confidence Target Indices by matching ground truth boxes
11
+ with (default) 'priorboxes' that have jaccard index > threshold parameter
12
+ (default threshold: 0.5).
13
+ 2) Produce localization target by 'encoding' variance into offsets of ground
14
+ truth boxes and their matched 'priorboxes'.
15
+ 3) Hard negative mining to filter the excessive number of negative examples
16
+ that comes with using a large number of default bounding boxes.
17
+ (default negative:positive ratio 3:1)
18
+ Objective Loss:
19
+ $L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N$
20
+ Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
21
+ weighted by α which is set to 1 by cross val.
22
+ Args:
23
+ c: class confidences,
24
+ l: predicted boxes,
25
+ g: ground truth boxes
26
+ N: number of matched default boxes
27
+ See: https://arxiv.org/pdf/1512.02325.pdf for more details.
28
+ """
29
+
30
+ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
31
+ super(MultiBoxLoss, self).__init__()
32
+ self.num_classes = num_classes
33
+ self.threshold = overlap_thresh
34
+ self.background_label = bkg_label
35
+ self.encode_target = encode_target
36
+ self.use_prior_for_matching = prior_for_matching
37
+ self.do_neg_mining = neg_mining
38
+ self.negpos_ratio = neg_pos
39
+ self.neg_overlap = neg_overlap
40
+
41
+
42
+ def forward(self, predictions, priorbox, targets):
43
+ """Multibox Loss
44
+ Args:
45
+ predictions (tuple): A tuple containing loc preds, conf preds,
46
+ and prior boxes from SSD net.
47
+ conf shape: torch.size(batch_size,num_priors,num_classes)
48
+ loc shape: torch.size(batch_size,num_priors,4)
49
+ priors shape: torch.size(num_priors,4)
50
+
51
+ ground_truth (tensor): Ground truth boxes and labels for a batch,
52
+ shape: [batch_size,num_objs,5] (last idx is the label).
53
+ """
54
+
55
+ loc_data, conf_data, landm_data, aggs, thetas = predictions
56
+ num = loc_data.size(0)
57
+ num_priors = (priorbox.priors.size(0))
58
+
59
+ if aggs is not None:
60
+ stacked_target = torch.stack(targets, dim=0).squeeze(1)
61
+
62
+ pos_idx = stacked_target[:, -1] > 0
63
+ agg_ldmk = aggs[:, 6:][pos_idx]
64
+ tgt_ldmk = stacked_target[:, 4:14][pos_idx]
65
+ agg_loss_landm = F.smooth_l1_loss(agg_ldmk, tgt_ldmk, reduction='sum') / len(tgt_ldmk)
66
+
67
+ pos_idx = stacked_target[:, -1] != 0
68
+ agg_bbox = aggs[:, :4][pos_idx]
69
+ tgt_bbox = stacked_target[:, :4][pos_idx]
70
+ agg_loss_box = F.smooth_l1_loss(agg_bbox, tgt_bbox, reduction='sum') / len(tgt_bbox)
71
+
72
+ agg_cls = aggs[:, 4:6]
73
+ tgt_cls = (stacked_target[:, -1] > 0).long()
74
+ agg_loss_cls = F.cross_entropy(agg_cls, tgt_cls, reduction='sum') / len(tgt_cls)
75
+ aux_loss_dict = {
76
+ 'agg_loss_landm': agg_loss_landm,
77
+ 'agg_loss_box': agg_loss_box,
78
+ 'agg_loss_cls': agg_loss_cls
79
+ }
80
+ else:
81
+ aux_loss_dict = None
82
+
83
+ # match priors (default boxes) and ground truth boxes
84
+ loc_t = torch.Tensor(num, num_priors, 4)
85
+ landm_t = torch.Tensor(num, num_priors, 10)
86
+ conf_t = torch.LongTensor(num, num_priors)
87
+ for idx in range(num):
88
+ truths = targets[idx][:, :4].data
89
+ labels = targets[idx][:, -1].data
90
+ landms = targets[idx][:, 4:14].data
91
+ match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)
92
+
93
+ loc_t = loc_t.cuda()
94
+ conf_t = conf_t.cuda()
95
+ landm_t = landm_t.cuda()
96
+ zeros = torch.tensor(0).cuda()
97
+ # landm Loss (Smooth L1)
98
+ # Shape: [batch,num_priors,10]
99
+ pos1 = conf_t > zeros
100
+ num_pos_landm = pos1.long().sum(1, keepdim=True)
101
+ N1 = max(num_pos_landm.data.sum().float(), 1)
102
+ pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
103
+ landm_p = landm_data[pos_idx1].view(-1, 10)
104
+ landm_t = landm_t[pos_idx1].view(-1, 10)
105
+ loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
106
+
107
+
108
+ pos = conf_t != zeros
109
+ conf_t[pos] = 1
110
+
111
+ # Localization Loss (Smooth L1)
112
+ # Shape: [batch,num_priors,4]
113
+ pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
114
+ loc_p = loc_data[pos_idx].view(-1, 4)
115
+ loc_t = loc_t[pos_idx].view(-1, 4)
116
+ loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
117
+
118
+ # Compute max conf across batch for hard negative mining
119
+ batch_conf = conf_data.view(-1, self.num_classes)
120
+ loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
121
+
122
+ # Hard Negative Mining
123
+ loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
124
+ loss_c = loss_c.view(num, -1)
125
+ _, loss_idx = loss_c.sort(1, descending=True)
126
+ _, idx_rank = loss_idx.sort(1)
127
+ num_pos = pos.long().sum(1, keepdim=True)
128
+ num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
129
+ neg = idx_rank < num_neg.expand_as(idx_rank)
130
+
131
+ # Confidence Loss Including Positive and Negative Examples
132
+ pos_idx = pos.unsqueeze(2).expand_as(conf_data)
133
+ neg_idx = neg.unsqueeze(2).expand_as(conf_data)
134
+ conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
135
+ targets_weighted = conf_t[(pos+neg).gt(0)]
136
+ loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
137
+
138
+ # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
139
+ N = max(num_pos.data.sum().float(), 1)
140
+ loss_l /= N
141
+ loss_c /= N
142
+ loss_landm /= N1
143
+
144
+ return loss_l, loss_c, loss_landm, aux_loss_dict
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/__init__.py ADDED
File without changes
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/net.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models._utils as _utils
5
+ import torchvision.models as models
6
+ import torch.nn.functional as F
7
+ from torch.autograd import Variable
8
+
9
+ def conv_bn(inp, oup, stride = 1, leaky = 0):
10
+ return nn.Sequential(
11
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12
+ nn.BatchNorm2d(oup),
13
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
14
+ )
15
+
16
+ def conv_bn_no_relu(inp, oup, stride):
17
+ return nn.Sequential(
18
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
19
+ nn.BatchNorm2d(oup),
20
+ )
21
+
22
+ def conv_bn1X1(inp, oup, stride, leaky=0):
23
+ return nn.Sequential(
24
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
25
+ nn.BatchNorm2d(oup),
26
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
27
+ )
28
+
29
+ def conv_dw(inp, oup, stride, leaky=0.1):
30
+ return nn.Sequential(
31
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32
+ nn.BatchNorm2d(inp),
33
+ nn.LeakyReLU(negative_slope= leaky,inplace=True),
34
+
35
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36
+ nn.BatchNorm2d(oup),
37
+ nn.LeakyReLU(negative_slope= leaky,inplace=True),
38
+ )
39
+
40
+ class SSH(nn.Module):
41
+ def __init__(self, in_channel, out_channel):
42
+ super(SSH, self).__init__()
43
+ assert out_channel % 4 == 0
44
+ leaky = 0
45
+ if (out_channel <= 64):
46
+ leaky = 0.1
47
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
48
+
49
+ self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
50
+ self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
51
+
52
+ self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
53
+ self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
54
+
55
+ def forward(self, input):
56
+ conv3X3 = self.conv3X3(input)
57
+
58
+ conv5X5_1 = self.conv5X5_1(input)
59
+ conv5X5 = self.conv5X5_2(conv5X5_1)
60
+
61
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
62
+ conv7X7 = self.conv7x7_3(conv7X7_2)
63
+
64
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
65
+ out = F.relu(out)
66
+ return out
67
+
68
+ class FPN(nn.Module):
69
+ def __init__(self,in_channels_list,out_channels):
70
+ super(FPN,self).__init__()
71
+ leaky = 0
72
+ if (out_channels <= 64):
73
+ leaky = 0.1
74
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
75
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
76
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
77
+
78
+ self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
79
+ self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
80
+
81
+ def forward(self, input):
82
+ # names = list(input.keys())
83
+ input = list(input.values())
84
+
85
+ output1 = self.output1(input[0])
86
+ output2 = self.output2(input[1])
87
+ output3 = self.output3(input[2])
88
+
89
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
90
+ output2 = output2 + up3
91
+ output2 = self.merge2(output2)
92
+
93
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
94
+ output1 = output1 + up2
95
+ output1 = self.merge1(output1)
96
+
97
+ out = [output1, output2, output3]
98
+ return out
99
+
100
+
101
+
102
+ class MobileNetV1(nn.Module):
103
+ def __init__(self):
104
+ super(MobileNetV1, self).__init__()
105
+ self.stage1 = nn.Sequential(
106
+ conv_bn(3, 8, 2, leaky = 0.1), # 3
107
+ conv_dw(8, 16, 1), # 7
108
+ conv_dw(16, 32, 2), # 11
109
+ conv_dw(32, 32, 1), # 19
110
+ conv_dw(32, 64, 2), # 27
111
+ conv_dw(64, 64, 1), # 43
112
+ )
113
+ self.stage2 = nn.Sequential(
114
+ conv_dw(64, 128, 2), # 43 + 16 = 59
115
+ conv_dw(128, 128, 1), # 59 + 32 = 91
116
+ conv_dw(128, 128, 1), # 91 + 32 = 123
117
+ conv_dw(128, 128, 1), # 123 + 32 = 155
118
+ conv_dw(128, 128, 1), # 155 + 32 = 187
119
+ conv_dw(128, 128, 1), # 187 + 32 = 219
120
+ )
121
+ self.stage3 = nn.Sequential(
122
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
123
+ conv_dw(256, 256, 1), # 241 + 64 = 301
124
+ )
125
+
126
+ def forward(self, x):
127
+ x = self.stage1(x)
128
+ x = self.stage2(x)
129
+ x = self.stage3(x)
130
+ return x
131
+
132
+
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/retinaface.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models._utils as _utils
4
+ import torch.nn.functional as F
5
+ from .net import MobileNetV1 as MobileNetV1
6
+ from .net import FPN as FPN
7
+ from .net import SSH as SSH
8
+
9
+ from timm.models import mlp_mixer
10
+
11
+ class ClassHead(nn.Module):
12
+ def __init__(self,inchannels=512,num_anchors=3):
13
+ super(ClassHead,self).__init__()
14
+ self.num_anchors = num_anchors
15
+ self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
16
+
17
+ def forward(self,x):
18
+ out = self.conv1x1(x)
19
+ out = out.permute(0,2,3,1).contiguous()
20
+
21
+ return out.view(out.shape[0], -1, 2)
22
+
23
+ class BboxHead(nn.Module):
24
+ def __init__(self,inchannels=512,num_anchors=3):
25
+ super(BboxHead,self).__init__()
26
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
27
+
28
+ def forward(self,x):
29
+ out = self.conv1x1(x)
30
+ out = out.permute(0,2,3,1).contiguous()
31
+
32
+ return out.view(out.shape[0], -1, 4)
33
+
34
+ class LandmarkHead(nn.Module):
35
+ def __init__(self,inchannels=512,num_anchors=3):
36
+ super(LandmarkHead,self).__init__()
37
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
38
+
39
+ def forward(self,x):
40
+ out = self.conv1x1(x)
41
+ out = out.permute(0,2,3,1).contiguous()
42
+
43
+ return out.view(out.shape[0], -1, 10)
44
+
45
+ class RetinaFace(nn.Module):
46
+ def __init__(self, cfg = None, phase = 'train', use_aggregator=False):
47
+ """
48
+ :param cfg: Network related settings.
49
+ :param phase: train or test.
50
+ """
51
+ super(RetinaFace,self).__init__()
52
+ self.phase = phase
53
+ backbone = None
54
+ if cfg['name'] == 'mobilenet0.25':
55
+ backbone = MobileNetV1()
56
+ # if cfg['pretrain']:
57
+ # checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
58
+ # from collections import OrderedDict
59
+ # new_state_dict = OrderedDict()
60
+ # for k, v in checkpoint['state_dict'].items():
61
+ # name = k[7:] # remove module.
62
+ # new_state_dict[name] = v
63
+ # load params
64
+ # backbone.load_state_dict(new_state_dict)
65
+ elif cfg['name'] == 'Resnet50':
66
+ import torchvision.models as models
67
+ backbone = models.resnet50(pretrained=cfg['pretrain'])
68
+
69
+ self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
70
+ in_channels_stage2 = cfg['in_channel']
71
+ in_channels_list = [
72
+ in_channels_stage2 * 2,
73
+ in_channels_stage2 * 4,
74
+ in_channels_stage2 * 8,
75
+ ]
76
+ out_channels = cfg['out_channel']
77
+ self.fpn = FPN(in_channels_list,out_channels)
78
+ self.ssh1 = SSH(out_channels, out_channels)
79
+ self.ssh2 = SSH(out_channels, out_channels)
80
+ self.ssh3 = SSH(out_channels, out_channels)
81
+
82
+ self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
83
+ self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
84
+ self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
85
+
86
+ self.use_aggregator = use_aggregator
87
+ if self.use_aggregator:
88
+ modules = [mlp_mixer.MixerBlock(16, 1050) for _ in range(3)]
89
+ modules.append(nn.Linear(16, 1))
90
+ self.aggregator = nn.Sequential(*modules)
91
+ else:
92
+ self.aggregator = None
93
+
94
+ def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
95
+ classhead = nn.ModuleList()
96
+ for i in range(fpn_num):
97
+ classhead.append(ClassHead(inchannels,anchor_num))
98
+ return classhead
99
+
100
+ def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
101
+ bboxhead = nn.ModuleList()
102
+ for i in range(fpn_num):
103
+ bboxhead.append(BboxHead(inchannels,anchor_num))
104
+ return bboxhead
105
+
106
+ def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
107
+ landmarkhead = nn.ModuleList()
108
+ for i in range(fpn_num):
109
+ landmarkhead.append(LandmarkHead(inchannels,anchor_num))
110
+ return landmarkhead
111
+
112
+ def forward(self, inputs, priorbox):
113
+ out = self.body(inputs)
114
+
115
+ # FPN
116
+ fpn = self.fpn(out)
117
+
118
+ # SSH
119
+ feature1 = self.ssh1(fpn[0])
120
+ feature2 = self.ssh2(fpn[1])
121
+ feature3 = self.ssh3(fpn[2])
122
+ features = [feature1, feature2, feature3]
123
+
124
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
125
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
126
+ ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
127
+ if self.use_aggregator:
128
+ decoded_bbox = priorbox.decode_batch(bbox_regressions)
129
+ decoded_ldmk = priorbox.decode_landm_batch(ldm_regressions)
130
+ combined = torch.cat([decoded_bbox, classifications, decoded_ldmk], dim=2)
131
+ weight = self.aggregator(combined)
132
+ weight = F.softmax(weight, dim=1)
133
+ agg = torch.sum(weight * combined, dim=1)
134
+ theta = None
135
+ else:
136
+ agg = None
137
+ theta = None
138
+ if self.phase == 'train':
139
+ output = (bbox_regressions, classifications, ldm_regressions, agg, theta)
140
+ else:
141
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions, agg, theta)
142
+ return output
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/preprocessor.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ class Preprocessor():
5
+
6
+ def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
7
+ self.output_size = output_size
8
+ self.padding = padding
9
+ self.padding_val = padding_val
10
+
11
+ def preprocess_batched(self, imgs, padding_ratio_override=None):
12
+
13
+ # check img is of float
14
+ if imgs.dtype == torch.float32:
15
+ if self.padding_val == 'zero':
16
+ padding_val = -1.0
17
+ elif self.padding_val == 'mean':
18
+ padding_val = imgs.mean()
19
+ else:
20
+ raise ValueError('padding_val must be "zero" or "mean"')
21
+ elif imgs.dtype == torch.uint8:
22
+ if self.padding_val == 'zero':
23
+ padding_val = 0
24
+ elif self.padding_val == 'mean':
25
+ padding_val = imgs.mean()
26
+ else:
27
+ raise ValueError('padding_val must be "zero" or "mean"')
28
+ else:
29
+ raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
30
+
31
+ square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)
32
+
33
+ if padding_ratio_override is not None:
34
+ padding = padding_ratio_override
35
+ else:
36
+ padding = self.padding
37
+ padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)
38
+
39
+ size=(self.output_size, self.output_size)
40
+ if imgs.dtype == torch.float32:
41
+ resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
42
+ elif imgs.dtype == torch.uint8:
43
+ padded_imgs = padded_imgs.to(torch.float32)
44
+ resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
45
+ resized_imgs = torch.clip(resized_imgs, 0, 255)
46
+ resized_imgs = resized_imgs.to(torch.uint8)
47
+ else:
48
+ raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
49
+ return resized_imgs
50
+
51
+
52
+ def make_square_img_batched(self, imgs, padding_val):
53
+ assert imgs.ndim == 4
54
+ # squarify the image
55
+ h, w = imgs.shape[2:]
56
+ if h > w:
57
+ diff = (h - w)
58
+ pad_left = diff // 2
59
+ pad_right = diff - pad_left
60
+ imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
61
+ elif w > h:
62
+ diff = (w - h)
63
+ pad_top = diff // 2
64
+ pad_bottom = diff - pad_top
65
+ imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
66
+ assert imgs.shape[2] == imgs.shape[3]
67
+ return imgs
68
+
69
+
70
+ def make_padded_img_batched(self, imgs, padding, padding_val):
71
+ if padding == 0:
72
+ return imgs
73
+ assert imgs.ndim == 4
74
+
75
+
76
+ # pad the image
77
+ h, w = imgs.shape[2:]
78
+ pad_h = int(h * padding)
79
+ pad_w = int(w * padding)
80
+ imgs = F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)
81
+ return imgs
82
+
83
+
84
+ def __call__(self, input, padding_ratio_override=None):
85
+ if input.ndim == 3:
86
+ assert input.shape[0] == 3
87
+ batch_input = input.unsqueeze(0)
88
+ return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0]
89
+ elif input.ndim == 4:
90
+ assert input.shape[1] == 3
91
+ return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)
92
+ else:
93
+ raise ValueError(f'Invalid input shape: {input.shape}')
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/box_utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def point_form(boxes):
6
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7
+ representation for comparison to point form ground truth data.
8
+ Args:
9
+ boxes: (tensor) center-size default boxes from priorbox layers.
10
+ Return:
11
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12
+ """
13
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15
+
16
+
17
+ def center_size(boxes):
18
+ """ Convert prior_boxes to (cx, cy, w, h)
19
+ representation for comparison to center-size form ground truth data.
20
+ Args:
21
+ boxes: (tensor) point_form boxes
22
+ Return:
23
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
24
+ """
25
+ return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
26
+ boxes[:, 2:] - boxes[:, :2], 1) # w, h
27
+
28
+
29
+ def intersect(box_a, box_b):
30
+ """ We resize both tensors to [A,B,2] without new malloc:
31
+ [A,2] -> [A,1,2] -> [A,B,2]
32
+ [B,2] -> [1,B,2] -> [A,B,2]
33
+ Then we compute the area of intersect between box_a and box_b.
34
+ Args:
35
+ box_a: (tensor) bounding boxes, Shape: [A,4].
36
+ box_b: (tensor) bounding boxes, Shape: [B,4].
37
+ Return:
38
+ (tensor) intersection area, Shape: [A,B].
39
+ """
40
+ A = box_a.size(0)
41
+ B = box_b.size(0)
42
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46
+ inter = torch.clamp((max_xy - min_xy), min=0)
47
+ return inter[:, :, 0] * inter[:, :, 1]
48
+
49
+
50
+ def jaccard(box_a, box_b):
51
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52
+ is simply the intersection over union of two boxes. Here we operate on
53
+ ground truth boxes and default boxes.
54
+ E.g.:
55
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56
+ Args:
57
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59
+ Return:
60
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61
+ """
62
+ inter = intersect(box_a, box_b)
63
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
64
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
66
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67
+ union = area_a + area_b - inter
68
+ return inter / union # [A,B]
69
+
70
+
71
+ def matrix_iou(a, b):
72
+ """
73
+ return iou of a and b, numpy version for data augenmentation
74
+ """
75
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77
+
78
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82
+
83
+
84
+ def matrix_iof(a, b):
85
+ """
86
+ return iof of a and b, numpy version for data augenmentation
87
+ """
88
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90
+
91
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
94
+
95
+
96
+ def match(threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx):
97
+ """Match each prior box with the ground truth box of the highest jaccard
98
+ overlap, encode the bounding boxes, then return the matched indices
99
+ corresponding to both confidence and location preds.
100
+ Args:
101
+ threshold: (float) The overlap threshold used when mathing boxes.
102
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
103
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104
+ variances: (tensor) Variances corresponding to each prior coord,
105
+ Shape: [num_priors, 4].
106
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
107
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
108
+ loc_t: (tensor) Tensor to be filled w/ endcoded location targets.
109
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
110
+ landm_t: (tensor) Tensor to be filled w/ endcoded landm targets.
111
+ idx: (int) current batch index
112
+ Return:
113
+ The matched indices corresponding to 1)location 2)confidence 3)landm preds.
114
+ """
115
+
116
+
117
+ # jaccard index
118
+ overlaps = jaccard(
119
+ truths,
120
+ point_form(priorbox.priors)
121
+ )
122
+ # (Bipartite Matching)
123
+ # [1,num_objects] best prior for each ground truth
124
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
125
+
126
+ # ignore hard gt
127
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
128
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
129
+ if best_prior_idx_filter.shape[0] <= 0:
130
+ loc_t[idx] = 0
131
+ conf_t[idx] = 0
132
+ return
133
+
134
+ # [1,num_priors] best ground truth for each prior
135
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
136
+ best_truth_idx.squeeze_(0)
137
+ best_truth_overlap.squeeze_(0)
138
+ best_prior_idx.squeeze_(1)
139
+ best_prior_idx_filter.squeeze_(1)
140
+ best_prior_overlap.squeeze_(1)
141
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
142
+ # TODO refactor: index best_prior_idx with long tensor
143
+ # ensure every gt matches with its prior of max overlap
144
+ for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
145
+ best_truth_idx[best_prior_idx[j]] = j
146
+ matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
147
+ conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
148
+ conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
149
+ loc = priorbox.encode(matches)
150
+
151
+ matches_landm = landms[best_truth_idx]
152
+ landm = priorbox.encode_landm(matches_landm)
153
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
154
+ conf_t[idx] = conf # [num_priors] top class label for each prior
155
+ landm_t[idx] = landm
156
+
157
+
158
+
159
+ def log_sum_exp(x):
160
+ """Utility function for computing log_sum_exp while determining
161
+ This will be used to determine unaveraged confidence loss across
162
+ all examples in a batch.
163
+ Args:
164
+ x (Variable(tensor)): conf_preds from conf layers
165
+ """
166
+ x_max = x.data.max()
167
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
168
+
169
+
170
+ # Original author: Francisco Massa:
171
+ # https://github.com/fmassa/object-detection.torch
172
+ # Ported to PyTorch by Max deGroot (02/01/2017)
173
+ def nms(boxes, scores, overlap=0.5, top_k=200):
174
+ """Apply non-maximum suppression at test time to avoid detecting too many
175
+ overlapping bounding boxes for a given object.
176
+ Args:
177
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
178
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
179
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
180
+ top_k: (int) The Maximum number of box preds to consider.
181
+ Return:
182
+ The indices of the kept boxes with respect to num_priors.
183
+ """
184
+
185
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
186
+ if boxes.numel() == 0:
187
+ return keep
188
+ x1 = boxes[:, 0]
189
+ y1 = boxes[:, 1]
190
+ x2 = boxes[:, 2]
191
+ y2 = boxes[:, 3]
192
+ area = torch.mul(x2 - x1, y2 - y1)
193
+ v, idx = scores.sort(0) # sort in ascending order
194
+ # I = I[v >= 0.01]
195
+ idx = idx[-top_k:] # indices of the top-k largest vals
196
+ xx1 = boxes.new()
197
+ yy1 = boxes.new()
198
+ xx2 = boxes.new()
199
+ yy2 = boxes.new()
200
+ w = boxes.new()
201
+ h = boxes.new()
202
+
203
+ # keep = torch.Tensor()
204
+ count = 0
205
+ while idx.numel() > 0:
206
+ i = idx[-1] # index of current largest val
207
+ # keep.append(i)
208
+ keep[count] = i
209
+ count += 1
210
+ if idx.size(0) == 1:
211
+ break
212
+ idx = idx[:-1] # remove kept element from view
213
+ # load bboxes of next highest vals
214
+ torch.index_select(x1, 0, idx, out=xx1)
215
+ torch.index_select(y1, 0, idx, out=yy1)
216
+ torch.index_select(x2, 0, idx, out=xx2)
217
+ torch.index_select(y2, 0, idx, out=yy2)
218
+ # store element-wise max with next highest score
219
+ xx1 = torch.clamp(xx1, min=x1[i])
220
+ yy1 = torch.clamp(yy1, min=y1[i])
221
+ xx2 = torch.clamp(xx2, max=x2[i])
222
+ yy2 = torch.clamp(yy2, max=y2[i])
223
+ w.resize_as_(xx2)
224
+ h.resize_as_(yy2)
225
+ w = xx2 - xx1
226
+ h = yy2 - yy1
227
+ # check sizes of xx1 and xx2.. after each iteration
228
+ w = torch.clamp(w, min=0.0)
229
+ h = torch.clamp(h, min=0.0)
230
+ inter = w*h
231
+ # IoU = i / (area(a) + area(b) - i)
232
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
233
+ union = (rem_areas - inter) + area[i]
234
+ IoU = inter/union # store result in iou
235
+ # keep only elements with an IoU <= overlap
236
+ idx = idx[IoU.le(overlap)]
237
+ return keep, count
238
+
239
+
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/model_utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def remove_prefix(state_dict, prefix):
4
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
5
+ print('remove prefix \'{}\''.format(prefix))
6
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
7
+ return {f(key): value for key, value in state_dict.items()}
8
+
9
+ def check_keys(model, pretrained_state_dict):
10
+ ckpt_keys = set(pretrained_state_dict.keys())
11
+ model_keys = set(model.state_dict().keys())
12
+ used_pretrained_keys = model_keys & ckpt_keys
13
+ unused_pretrained_keys = ckpt_keys - model_keys
14
+ missing_keys = model_keys - ckpt_keys
15
+ print('Missing keys:{}'.format(len(missing_keys)))
16
+ print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
17
+ print('Used keys:{}'.format(len(used_pretrained_keys)))
18
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
19
+ return True
20
+
21
+ def load_model(model, pretrained_path, load_to_cpu):
22
+ print('Loading pretrained model from {}'.format(pretrained_path))
23
+ if load_to_cpu:
24
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
25
+ else:
26
+ device = torch.cuda.current_device()
27
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
28
+ if "state_dict" in pretrained_dict.keys():
29
+ pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
30
+ else:
31
+ pretrained_dict = remove_prefix(pretrained_dict, 'module.')
32
+ check_keys(model, pretrained_dict)
33
+ model.load_state_dict(pretrained_dict, strict=False)
34
+ return model
35
+
36
+
cvlface/research/recognition/code/run_v1/aligners/none/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseAligner
2
+
3
+
4
+ class NoneAligner(BaseAligner):
5
+ def __init__(self, config):
6
+ super().__init__()
7
+ self.config = config
8
+
9
+ @classmethod
10
+ def from_config(cls, aligner_config):
11
+ return cls(aligner_config)
12
+
13
+ def make_train_transform(self):
14
+ return lambda x:x
15
+
16
+ def make_test_transform(self):
17
+ return lambda x:x
18
+
19
+ def forward(self, x):
20
+ return x
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseAligner
2
+ from torchvision import transforms
3
+ from .retinaface import get_landmark_predictor, get_preprocessor
4
+ from . import aligner_helper
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+ class RetinaFaceAligner(BaseAligner):
10
+
11
+ """
12
+ A non-differentiable face aligner that aligns the image with one face to a canonical position.
13
+ The aligner is based on the following paper:
14
+
15
+ ```
16
+ @inproceedings{deng2020retinaface,
17
+ title={Retinaface: Single-shot multi-level face localisation in the wild},
18
+ author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos},
19
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
20
+ pages={5203--5212},
21
+ year={2020}
22
+ }
23
+ ```
24
+ """
25
+
26
+ def __init__(self, net, prior_box, preprocessor, config):
27
+ super(RetinaFaceAligner, self).__init__()
28
+ self.net = net
29
+ self.prior_box = prior_box
30
+ self.preprocessor = preprocessor
31
+ self.config = config
32
+
33
+ @classmethod
34
+ def from_config(cls, config):
35
+ net, prior_box = get_landmark_predictor(network=config.arch,
36
+ input_size=config.input_size)
37
+
38
+ preprocessor = get_preprocessor(output_size=config.input_size,
39
+ padding=config.input_padding_ratio,
40
+ padding_val=config.input_padding_val)
41
+ if config.freeze:
42
+ for param in net.parameters():
43
+ param.requires_grad = False
44
+ model = cls(net, prior_box, preprocessor, config)
45
+ model.eval()
46
+ return model
47
+
48
+ def forward(self, x, padding_ratio_override=None):
49
+
50
+ # input size check
51
+ assert x.shape[1] == 3
52
+ assert x.ndim == 4
53
+ assert isinstance(x, torch.Tensor)
54
+ is_square = x.shape[2] == x.shape[3]
55
+
56
+ x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
57
+ assert self.prior_box.image_size == x.shape[2:]
58
+
59
+ # make image into BGR
60
+ x_bgr = x.flip(1)
61
+ input_img = normalize_for_net(unnormalize(x_bgr))
62
+
63
+ result = self.net(input_img, self.prior_box)
64
+ batch_loc, batch_conf, batch_landms = result
65
+ batch_loc = torch.split(batch_loc, 1, dim=0)
66
+ batch_conf = torch.split(batch_conf, 1, dim=0)
67
+ batch_landms = torch.split(batch_landms, 1, dim=0)
68
+
69
+ nms_ldmks = []
70
+ nms_scores = []
71
+ nms_bbox = []
72
+ for loc, conf, landms, in zip(batch_loc, batch_conf, batch_landms):
73
+ dets = postprocess(self.prior_box, loc, conf, landms, confidence_threshold=0.0, nms_threshold=0.4)
74
+ bbox, score, ldmks = parse_one_det_result(dets)
75
+ ldmks = ldmks / np.array( [self.prior_box.image_size[0], self.prior_box.image_size[1]] * 5)
76
+ nms_ldmks.append(ldmks)
77
+ nms_scores.append(score)
78
+ nms_bbox.append(bbox)
79
+
80
+ orig_pred_ldmks = torch.from_numpy(np.array(nms_ldmks)).to(self.device).float()
81
+ score = torch.from_numpy(np.array(nms_scores)).to(self.device).float().unsqueeze(-1)
82
+ bbox = torch.from_numpy(np.array(nms_bbox)).to(self.device).float()
83
+
84
+
85
+ reference_ldmk = aligner_helper.reference_landmark()
86
+ input_size = self.config.input_size
87
+ output_size = self.config.output_size
88
+ cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
89
+ thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
90
+ thetas = thetas.to(orig_pred_ldmks.device)
91
+
92
+ output_size = torch.Size((len(thetas), 3, output_size, output_size))
93
+ grid = F.affine_grid(thetas, output_size, align_corners=True)
94
+ aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1 # +1, -1 for making padding pixel 0
95
+ aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)
96
+
97
+ orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
98
+ # bbox (xmin, ymin, xmax, ymax)
99
+ normalized_bbox = bbox / torch.tensor([[input_img.size(3), input_img.size(2)] * 2]).to(bbox.device)
100
+
101
+
102
+ if padding_ratio_override is None:
103
+ padding_ratio = self.preprocessor.padding
104
+ else:
105
+ padding_ratio = padding_ratio_override
106
+ if padding_ratio > 0:
107
+ # unpad the landmark so that it is in the original image coordinate
108
+ scale = 1 / (1 + (2 * padding_ratio))
109
+ pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
110
+ pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
111
+ unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
112
+ torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
113
+ unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
114
+ unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach()
115
+ unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2)
116
+ if not is_square:
117
+ unpad_ldmk_pred = None # cannot use this if the input is not square becaouse preprocessor changes input
118
+ normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input
119
+ return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox
120
+
121
+ if not is_square:
122
+ orig_pred_ldmks = None # cannot use this if the input is not square becaouse preprocessor changes input
123
+ normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input
124
+ return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox
125
+
126
+ def make_train_transform(self):
127
+ transform = transforms.Compose([
128
+ transforms.ToTensor(),
129
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
130
+ ])
131
+ return transform
132
+
133
+ def make_test_transform(self):
134
+ transform = transforms.Compose([
135
+ transforms.ToTensor(),
136
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
137
+ ])
138
+ return transform
139
+
140
+
141
+ def normalize(image):
142
+ image = image / 255.
143
+ image = (image - 0.5) / 0.5
144
+ return image
145
+
146
+ def unnormalize(image):
147
+ image = image * 0.5 + 0.5
148
+ image = image * 255.
149
+ return image
150
+
151
+ def normalize_for_net(bgr_image_0_255):
152
+ # bgr_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
153
+ return bgr_image_0_255 - torch.tensor([104, 117, 123])[None, :, None, None].to(bgr_image_0_255.device)
154
+
155
+
156
+ def postprocess(priorbox, loc, conf, landms, confidence_threshold, nms_threshold):
157
+
158
+ device = loc.device
159
+ im_height, im_width = priorbox.image_size
160
+
161
+ scale = torch.Tensor([im_width, im_height, im_width, im_height])
162
+ scale = scale.to(device)
163
+
164
+ boxes = priorbox.decode(loc.data.squeeze(0))
165
+ boxes = boxes * scale
166
+ boxes = boxes.cpu().numpy()
167
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
168
+ landms = priorbox.decode_landm(landms.data.squeeze(0))
169
+ scale1 = torch.Tensor([im_width, im_height, im_width, im_height,
170
+ im_width, im_height, im_width, im_height,
171
+ im_width, im_height])
172
+ scale1 = scale1.to(device)
173
+ landms = landms * scale1
174
+ landms = landms.cpu().numpy()
175
+
176
+ # ignore low scores
177
+ inds = np.where(scores > confidence_threshold)[0]
178
+ if len(inds) == 0:
179
+ inds = np.where(scores >= 0)[0]
180
+ boxes = boxes[inds]
181
+ landms = landms[inds]
182
+ scores = scores[inds]
183
+
184
+ # keep top-K before NMS
185
+ order = scores.argsort()[::-1]
186
+ # order = scores.argsort()[::-1][:args.top_k]
187
+ boxes = boxes[order]
188
+ landms = landms[order]
189
+ scores = scores[order]
190
+
191
+ # do NMS
192
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
193
+ keep = py_cpu_nms(dets, nms_threshold)
194
+ # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
195
+ dets = dets[keep, :]
196
+ landms = landms[keep]
197
+
198
+ # keep top-K faster NMS
199
+ # dets = dets[:args.keep_top_k, :]
200
+ # landms = landms[:args.keep_top_k, :]
201
+
202
+ dets = np.concatenate((dets, landms), axis=1)
203
+ return dets
204
+
205
+
206
+ def py_cpu_nms(dets,
207
+ thresh):
208
+ """
209
+ Pure Python NMS baseline.
210
+ """
211
+ x1 = dets[:, 0]
212
+ y1 = dets[:, 1]
213
+ x2 = dets[:, 2]
214
+ y2 = dets[:, 3]
215
+ scores = dets[:, 4]
216
+
217
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
218
+ order = scores.argsort()[::-1]
219
+
220
+ keep = []
221
+ while order.size > 0:
222
+ i = order[0]
223
+ keep.append(i)
224
+ xx1 = np.maximum(x1[i], x1[order[1:]])
225
+ yy1 = np.maximum(y1[i], y1[order[1:]])
226
+ xx2 = np.minimum(x2[i], x2[order[1:]])
227
+ yy2 = np.minimum(y2[i], y2[order[1:]])
228
+
229
+ w = np.maximum(0.0, xx2 - xx1 + 1)
230
+ h = np.maximum(0.0, yy2 - yy1 + 1)
231
+ inter = w * h
232
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
233
+
234
+ inds = np.where(ovr <= thresh)[0]
235
+ order = order[inds + 1]
236
+
237
+ return keep
238
+
239
+
240
+ def parse_one_det_result(dets):
241
+ dets_sorted = dets[dets[:, 4].argsort()[::-1]]
242
+ result = dets_sorted[0]
243
+ bbox = result[:4]
244
+ score = result[4]
245
+ ldmks = result[5:]
246
+ return bbox, score, ldmks
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/aligner_helper.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from skimage import transform as trans
5
+ import cv2
6
+
7
+
8
+ def split_network_output(align_out):
9
+ anchor_bbox_pred, anchor_cls_pred, anchor_ldmk_pred, merged, _ = align_out
10
+ bbox, cls, ldmk = torch.split(merged, [4, 2, 10], dim=1)
11
+ return ldmk, bbox, cls
12
+
13
+
14
+ def get_cv2_affine_from_landmark(ldmks, reference_ldmk, image_width, image_height, ):
15
+ assert ldmks.ndim == 2 # batchdim
16
+ assert ldmks.shape[1] == 10
17
+ assert isinstance(ldmks, torch.Tensor)
18
+
19
+ assert reference_ldmk.ndim == 2
20
+ assert reference_ldmk.shape[0] == 5
21
+ assert reference_ldmk.shape[1] == 2
22
+ assert isinstance(reference_ldmk, np.ndarray)
23
+
24
+ to_img_size = np.array([[[image_width, image_height]]])
25
+ ldmks = ldmks.view(ldmks.shape[0], 5, 2).detach().cpu().numpy()
26
+ ldmks = ldmks * to_img_size
27
+ transforms = []
28
+ for ldmk in ldmks:
29
+ tform = trans.SimilarityTransform()
30
+ tform.estimate(ldmk, reference_ldmk)
31
+ M = tform.params[0:2, :]
32
+ transforms.append(M)
33
+ transforms = np.stack(transforms, axis=0)
34
+ return transforms
35
+
36
+
37
+ def cv2_param_to_torch_theta(cv2_tfms, image_width, image_height, output_width, output_height):
38
+ # https://github.com/wuneng/WarpAffine2GridSample
39
+ """4.Affine Transformation Matrix to theta"""
40
+ assert cv2_tfms.ndim == 3 # N, 2, 3
41
+ assert cv2_tfms.shape[1] == 2
42
+ assert cv2_tfms.shape[2] == 3
43
+
44
+ srcs = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.float32)
45
+ srcs = np.expand_dims(srcs, axis=0).repeat(cv2_tfms.shape[0], axis=0)
46
+ dsts = np.matmul(srcs, cv2_tfms[:, :, :2].transpose(0, 2, 1)) + cv2_tfms[:, :, 2:3].transpose(0, 2, 1)
47
+
48
+ # normalize to [-1, 1]
49
+ srcs = srcs / np.array([[[image_width, image_height]]]) * 2 - 1
50
+ dsts = dsts / np.array([[[output_width, output_height]]]) * 2 - 1
51
+
52
+ thetas = []
53
+ for src, dst in zip(srcs, dsts):
54
+ theta = trans.estimate_transform("affine", src=dst, dst=src).params[:2]
55
+ thetas.append(theta)
56
+ thetas = np.stack(thetas, axis=0)
57
+ thetas = torch.from_numpy(thetas).float()
58
+ return thetas
59
+
60
+
61
+ def adjust_ldmks(ldmks, thetas):
62
+ inv_thetas = inv_matrix(thetas).to(ldmks.device).float()
63
+ _ldmks = torch.cat([ldmks, torch.ones((ldmks.shape[0], 5, 1)).to(ldmks.device)], dim=2)
64
+ ldmk_aligned = (((_ldmks) * 2 - 1) @ inv_thetas.permute(0,2,1)) / 2 + 0.5
65
+ return ldmk_aligned
66
+
67
+
68
+ def inv_matrix(theta):
69
+ # torch batched version
70
+ assert theta.ndim == 3
71
+ a, b, t1 = theta[:, 0,0], theta[:, 0,1], theta[:, 0,2]
72
+ c, d, t2 = theta[:, 1,0], theta[:, 1,1], theta[:, 1,2]
73
+ det = a * d - b * c
74
+ inv_det = 1.0 / det
75
+ inv_mat = torch.stack([
76
+ torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1),
77
+ torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
78
+ ], dim=1)
79
+ return inv_mat
80
+
81
+ def reference_landmark():
82
+ return np.array([[38.29459953, 51.69630051],
83
+ [73.53179932, 51.50139999],
84
+ [56.02519989, 71.73660278],
85
+ [41.54930115, 92.3655014],
86
+ [70.72990036, 92.20410156]])
87
+
88
+
89
+ def draw_ldmk(img, ldmk):
90
+ if ldmk is None:
91
+ return img
92
+ colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
93
+ img = img.copy()
94
+ for i in range(5):
95
+ color = colors[i]
96
+ cv2.circle(img, (int(ldmk[i*2] * img.shape[1]), int(ldmk[i*2+1] * img.shape[0])), 1, color, 4)
97
+ return img
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models.retinaface import RetinaFace
2
+ from .utils.model_utils import load_model
3
+ from .config import cfg_mnet, cfg_re50
4
+ from .layers.functions.prior_box import PriorBox
5
+ from .preprocessor import Preprocessor
6
+
7
+ def get_landmark_predictor(network='mobile0.25', input_size=160):
8
+
9
+ cfg = None
10
+ if network == "mobile0.25":
11
+ cfg = cfg_mnet
12
+ elif network == "resnet50":
13
+ cfg = cfg_re50
14
+ net = RetinaFace(cfg=cfg, phase = 'test')
15
+ priorbox = PriorBox(image_size=(input_size, input_size),
16
+ # min_sizes=[[64, 80], [96, 112], [128, 144]],
17
+ min_sizes=[[16, 32], [64, 128], [256, 512]],
18
+ steps=[8, 16, 32],
19
+ clip=False,
20
+ variances=[0.1, 0.2],)
21
+
22
+ # aligner = Aligner(net, priorbox, input_size, output_size=output_size)
23
+ # return aligner
24
+ return net, priorbox
25
+
26
+
27
+ def get_preprocessor(output_size=160, padding=0.0, padding_val='zero'):
28
+ return Preprocessor(output_size=output_size, padding=padding, padding_val=padding_val)
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+
3
+ cfg_mnet = {
4
+ 'name': 'mobilenet0.25',
5
+ 'pretrain': True,
6
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
7
+ 'in_channel': 32,
8
+ 'out_channel': 64
9
+ }
10
+
11
+ cfg_re50 = {
12
+ 'name': 'Resnet50',
13
+ 'pretrain': True,
14
+ 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
15
+ 'in_channel': 256,
16
+ 'out_channel': 256
17
+ }
18
+
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import *
2
+ from .modules import *
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/functions/prior_box.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from itertools import product as product
3
+ from math import ceil
4
+
5
+
6
+ class PriorBox(object):
7
+
8
+ def __init__(self,
9
+ image_size,
10
+ min_sizes=[[64, 80], [96, 112], [128, 144]],
11
+ steps=[8,16,32],
12
+ clip=False,
13
+ variances=[0.1, 0.2],
14
+ ):
15
+ super(PriorBox, self).__init__()
16
+ self.min_sizes = min_sizes
17
+ self.steps = steps
18
+ self.clip = clip
19
+ self.variances = variances
20
+ self.image_size = image_size
21
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
22
+ with torch.no_grad():
23
+ self.priors = self.forward()
24
+
25
+ def forward(self):
26
+ anchors = []
27
+ for k, f in enumerate(self.feature_maps):
28
+ min_sizes = self.min_sizes[k]
29
+ for i, j in product(range(f[0]), range(f[1])):
30
+ for min_size in min_sizes:
31
+ s_kx = min_size / self.image_size[1]
32
+ s_ky = min_size / self.image_size[0]
33
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
34
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
35
+ for cy, cx in product(dense_cy, dense_cx):
36
+ anchors += [cx, cy, s_kx, s_ky]
37
+
38
+ # back to torch land
39
+ output = torch.Tensor(anchors).view(-1, 4)
40
+ # import pandas as pd
41
+ # pd.DataFrame(output.numpy()).to_csv('/mckim/temp/temp.csv')
42
+ if self.clip:
43
+ output.clamp_(max=1, min=0)
44
+ return output
45
+
46
+ def encode(self, matched):
47
+ """Encode the variances from the priorbox layers into the ground truth boxes
48
+ we have matched (based on jaccard overlap) with the prior boxes.
49
+ """
50
+ self.priors = self.priors.to(matched.device)
51
+
52
+ # dist b/t match center and prior's center
53
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - self.priors[:, :2]
54
+ # encode variance
55
+ g_cxcy /= (self.variances[0] * self.priors[:, 2:])
56
+ # match wh / prior wh
57
+ g_wh = (matched[:, 2:] - matched[:, :2]) / self.priors[:, 2:]
58
+ g_wh = torch.log(g_wh) / self.variances[1]
59
+ # return target for smooth_l1_loss
60
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
61
+
62
+ def encode_landm(self, matched):
63
+ """Encode the variances from the priorbox layers into the ground truth boxes
64
+ we have matched (based on jaccard overlap) with the prior boxes.
65
+ """
66
+ self.priors = self.priors.to(matched.device)
67
+
68
+ # dist b/t match center and prior's center
69
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
70
+ priors_cx = self.priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
71
+ priors_cy = self.priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
72
+ priors_w = self.priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
73
+ priors_h = self.priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
74
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
75
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
76
+ # encode variance
77
+ g_cxcy /= (self.variances[0] * priors[:, :, 2:])
78
+ # g_cxcy /= priors[:, :, 2:]
79
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
80
+ # return target for smooth_l1_loss
81
+ return g_cxcy
82
+
83
+
84
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
85
+ def decode(self, loc):
86
+ """Decode locations from predictions using priors to undo
87
+ the encoding we did for offset regression at train time.
88
+ """
89
+ self.priors = self.priors.to(loc.device)
90
+
91
+ boxes = torch.cat((
92
+ self.priors[:, :2] + loc[:, :2] * self.variances[0] * self.priors[:, 2:],
93
+ self.priors[:, 2:] * torch.exp(loc[:, 2:] * self.variances[1])), 1)
94
+ boxes[:, :2] -= boxes[:, 2:] / 2
95
+ boxes[:, 2:] += boxes[:, :2]
96
+ return boxes
97
+
98
+ def decode_landm(self, pre):
99
+ """Decode landm from predictions using priors to undo
100
+ the encoding we did for offset regression at train time.
101
+ """
102
+ self.priors = self.priors.to(pre.device)
103
+ landms = torch.cat((self.priors[:, :2] + pre[:, :2] * self.variances[0] * self.priors[:, 2:],
104
+ self.priors[:, :2] + pre[:, 2:4] * self.variances[0] * self.priors[:, 2:],
105
+ self.priors[:, :2] + pre[:, 4:6] * self.variances[0] * self.priors[:, 2:],
106
+ self.priors[:, :2] + pre[:, 6:8] * self.variances[0] * self.priors[:, 2:],
107
+ self.priors[:, :2] + pre[:, 8:10] * self.variances[0] * self.priors[:, 2:],
108
+ ), dim=1)
109
+ return landms
110
+
111
+
112
+ def decode_batch(self, loc):
113
+ """Decode locations from predictions using priors to undo
114
+ the encoding we did for offset regression at train time.
115
+ """
116
+ self.priors = self.priors.to(loc.device)
117
+ assert loc.ndim == 3
118
+ priors = self.priors.unsqueeze(0).expand(loc.size(0), -1, -1)
119
+ boxes = torch.cat((
120
+ priors[:, :, :2] + loc[:, :, :2] * self.variances[0] * priors[:, :, 2:],
121
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * self.variances[1])), -1)
122
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
123
+ boxes[:, :, 2:] += boxes[:, :, :2]
124
+ return boxes
125
+
126
+
127
+ def decode_landm_batch(self, prediction):
128
+ """Decode landm from prediction using priors to undo
129
+ the encoding we did for offset regression at train time.
130
+ """
131
+ assert prediction.ndim == 3
132
+ self.priors = self.priors.to(prediction.device)
133
+ priors = self.priors.unsqueeze(0).expand(prediction.size(0), -1, -1)
134
+ landms = torch.cat((priors[:, :, :2] + prediction[:, :, :2] * self.variances[0] * priors[:, :, 2:],
135
+ priors[:, :, :2] + prediction[:, :, 2:4] * self.variances[0] * priors[:, :, 2:],
136
+ priors[:, :, :2] + prediction[:, :, 4:6] * self.variances[0] * priors[:, :, 2:],
137
+ priors[:, :, :2] + prediction[:, :, 6:8] * self.variances[0] * priors[:, :, 2:],
138
+ priors[:, :, :2] + prediction[:, :, 8:10] * self.variances[0] * priors[:, :, 2:],
139
+ ), dim=-1)
140
+ return landms
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .multibox_loss import MultiBoxLoss
2
+
3
+ __all__ = ['MultiBoxLoss']
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/multibox_loss.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ...utils.box_utils import match, log_sum_exp
5
+
6
+
7
+ class MultiBoxLoss(nn.Module):
8
+ """SSD Weighted Loss Function
9
+ Compute Targets:
10
+ 1) Produce Confidence Target Indices by matching ground truth boxes
11
+ with (default) 'priorboxes' that have jaccard index > threshold parameter
12
+ (default threshold: 0.5).
13
+ 2) Produce localization target by 'encoding' variance into offsets of ground
14
+ truth boxes and their matched 'priorboxes'.
15
+ 3) Hard negative mining to filter the excessive number of negative examples
16
+ that comes with using a large number of default bounding boxes.
17
+ (default negative:positive ratio 3:1)
18
+ Objective Loss:
19
+ L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
20
+ Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
21
+ weighted by α which is set to 1 by cross val.
22
+ Args:
23
+ c: class confidences,
24
+ l: predicted boxes,
25
+ g: ground truth boxes
26
+ N: number of matched default boxes
27
+ See: https://arxiv.org/pdf/1512.02325.pdf for more details.
28
+ """
29
+
30
+ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
31
+ super(MultiBoxLoss, self).__init__()
32
+ self.num_classes = num_classes
33
+ self.threshold = overlap_thresh
34
+ self.background_label = bkg_label
35
+ self.encode_target = encode_target
36
+ self.use_prior_for_matching = prior_for_matching
37
+ self.do_neg_mining = neg_mining
38
+ self.negpos_ratio = neg_pos
39
+ self.neg_overlap = neg_overlap
40
+
41
+
42
+ def forward(self, predictions, priorbox, targets):
43
+ """Multibox Loss
44
+ Args:
45
+ predictions (tuple): A tuple containing loc preds, conf preds,
46
+ and prior boxes from SSD net.
47
+ conf shape: torch.size(batch_size,num_priors,num_classes)
48
+ loc shape: torch.size(batch_size,num_priors,4)
49
+ priors shape: torch.size(num_priors,4)
50
+
51
+ ground_truth (tensor): Ground truth boxes and labels for a batch,
52
+ shape: [batch_size,num_objs,5] (last idx is the label).
53
+ """
54
+
55
+ loc_data, conf_data, landm_data, aggs, thetas = predictions
56
+ num = loc_data.size(0)
57
+ num_priors = (priorbox.priors.size(0))
58
+
59
+ if aggs is not None:
60
+ stacked_target = torch.stack(targets, dim=0).squeeze(1)
61
+
62
+ pos_idx = stacked_target[:, -1] > 0
63
+ agg_ldmk = aggs[:, 6:][pos_idx]
64
+ tgt_ldmk = stacked_target[:, 4:14][pos_idx]
65
+ agg_loss_landm = F.smooth_l1_loss(agg_ldmk, tgt_ldmk, reduction='sum') / len(tgt_ldmk)
66
+
67
+ pos_idx = stacked_target[:, -1] != 0
68
+ agg_bbox = aggs[:, :4][pos_idx]
69
+ tgt_bbox = stacked_target[:, :4][pos_idx]
70
+ agg_loss_box = F.smooth_l1_loss(agg_bbox, tgt_bbox, reduction='sum') / len(tgt_bbox)
71
+
72
+ agg_cls = aggs[:, 4:6]
73
+ tgt_cls = (stacked_target[:, -1] > 0).long()
74
+ agg_loss_cls = F.cross_entropy(agg_cls, tgt_cls, reduction='sum') / len(tgt_cls)
75
+ aux_loss_dict = {
76
+ 'agg_loss_landm': agg_loss_landm,
77
+ 'agg_loss_box': agg_loss_box,
78
+ 'agg_loss_cls': agg_loss_cls
79
+ }
80
+ else:
81
+ aux_loss_dict = None
82
+
83
+ # match priors (default boxes) and ground truth boxes
84
+ loc_t = torch.Tensor(num, num_priors, 4)
85
+ landm_t = torch.Tensor(num, num_priors, 10)
86
+ conf_t = torch.LongTensor(num, num_priors)
87
+ for idx in range(num):
88
+ truths = targets[idx][:, :4].data
89
+ labels = targets[idx][:, -1].data
90
+ landms = targets[idx][:, 4:14].data
91
+ match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)
92
+
93
+ loc_t = loc_t.cuda()
94
+ conf_t = conf_t.cuda()
95
+ landm_t = landm_t.cuda()
96
+ zeros = torch.tensor(0).cuda()
97
+ # landm Loss (Smooth L1)
98
+ # Shape: [batch,num_priors,10]
99
+ pos1 = conf_t > zeros
100
+ num_pos_landm = pos1.long().sum(1, keepdim=True)
101
+ N1 = max(num_pos_landm.data.sum().float(), 1)
102
+ pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
103
+ landm_p = landm_data[pos_idx1].view(-1, 10)
104
+ landm_t = landm_t[pos_idx1].view(-1, 10)
105
+ loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
106
+
107
+
108
+ pos = conf_t != zeros
109
+ conf_t[pos] = 1
110
+
111
+ # Localization Loss (Smooth L1)
112
+ # Shape: [batch,num_priors,4]
113
+ pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
114
+ loc_p = loc_data[pos_idx].view(-1, 4)
115
+ loc_t = loc_t[pos_idx].view(-1, 4)
116
+ loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
117
+
118
+ # Compute max conf across batch for hard negative mining
119
+ batch_conf = conf_data.view(-1, self.num_classes)
120
+ loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
121
+
122
+ # Hard Negative Mining
123
+ loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
124
+ loss_c = loss_c.view(num, -1)
125
+ _, loss_idx = loss_c.sort(1, descending=True)
126
+ _, idx_rank = loss_idx.sort(1)
127
+ num_pos = pos.long().sum(1, keepdim=True)
128
+ num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
129
+ neg = idx_rank < num_neg.expand_as(idx_rank)
130
+
131
+ # Confidence Loss Including Positive and Negative Examples
132
+ pos_idx = pos.unsqueeze(2).expand_as(conf_data)
133
+ neg_idx = neg.unsqueeze(2).expand_as(conf_data)
134
+ conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
135
+ targets_weighted = conf_t[(pos+neg).gt(0)]
136
+ loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
137
+
138
+ # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
139
+ N = max(num_pos.data.sum().float(), 1)
140
+ loss_l /= N
141
+ loss_c /= N
142
+ loss_landm /= N1
143
+
144
+ return loss_l, loss_c, loss_landm, aux_loss_dict
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/__init__.py ADDED
File without changes
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/net.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models._utils as _utils
5
+ import torchvision.models as models
6
+ import torch.nn.functional as F
7
+ from torch.autograd import Variable
8
+
9
+ def conv_bn(inp, oup, stride = 1, leaky = 0):
10
+ return nn.Sequential(
11
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12
+ nn.BatchNorm2d(oup),
13
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
14
+ )
15
+
16
+ def conv_bn_no_relu(inp, oup, stride):
17
+ return nn.Sequential(
18
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
19
+ nn.BatchNorm2d(oup),
20
+ )
21
+
22
+ def conv_bn1X1(inp, oup, stride, leaky=0):
23
+ return nn.Sequential(
24
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
25
+ nn.BatchNorm2d(oup),
26
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
27
+ )
28
+
29
+ def conv_dw(inp, oup, stride, leaky=0.1):
30
+ return nn.Sequential(
31
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32
+ nn.BatchNorm2d(inp),
33
+ nn.LeakyReLU(negative_slope= leaky,inplace=True),
34
+
35
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36
+ nn.BatchNorm2d(oup),
37
+ nn.LeakyReLU(negative_slope= leaky,inplace=True),
38
+ )
39
+
40
+ class SSH(nn.Module):
41
+ def __init__(self, in_channel, out_channel):
42
+ super(SSH, self).__init__()
43
+ assert out_channel % 4 == 0
44
+ leaky = 0
45
+ if (out_channel <= 64):
46
+ leaky = 0.1
47
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
48
+
49
+ self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
50
+ self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
51
+
52
+ self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
53
+ self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
54
+
55
+ def forward(self, input):
56
+ conv3X3 = self.conv3X3(input)
57
+
58
+ conv5X5_1 = self.conv5X5_1(input)
59
+ conv5X5 = self.conv5X5_2(conv5X5_1)
60
+
61
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
62
+ conv7X7 = self.conv7x7_3(conv7X7_2)
63
+
64
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
65
+ out = F.relu(out)
66
+ return out
67
+
68
+ class FPN(nn.Module):
69
+ def __init__(self,in_channels_list,out_channels):
70
+ super(FPN,self).__init__()
71
+ leaky = 0
72
+ if (out_channels <= 64):
73
+ leaky = 0.1
74
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
75
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
76
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
77
+
78
+ self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
79
+ self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
80
+
81
+ def forward(self, input):
82
+ # names = list(input.keys())
83
+ input = list(input.values())
84
+
85
+ output1 = self.output1(input[0])
86
+ output2 = self.output2(input[1])
87
+ output3 = self.output3(input[2])
88
+
89
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
90
+ output2 = output2 + up3
91
+ output2 = self.merge2(output2)
92
+
93
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
94
+ output1 = output1 + up2
95
+ output1 = self.merge1(output1)
96
+
97
+ out = [output1, output2, output3]
98
+ return out
99
+
100
+
101
+
102
+ class MobileNetV1(nn.Module):
103
+ def __init__(self):
104
+ super(MobileNetV1, self).__init__()
105
+ self.stage1 = nn.Sequential(
106
+ conv_bn(3, 8, 2, leaky = 0.1), # 3
107
+ conv_dw(8, 16, 1), # 7
108
+ conv_dw(16, 32, 2), # 11
109
+ conv_dw(32, 32, 1), # 19
110
+ conv_dw(32, 64, 2), # 27
111
+ conv_dw(64, 64, 1), # 43
112
+ )
113
+ self.stage2 = nn.Sequential(
114
+ conv_dw(64, 128, 2), # 43 + 16 = 59
115
+ conv_dw(128, 128, 1), # 59 + 32 = 91
116
+ conv_dw(128, 128, 1), # 91 + 32 = 123
117
+ conv_dw(128, 128, 1), # 123 + 32 = 155
118
+ conv_dw(128, 128, 1), # 155 + 32 = 187
119
+ conv_dw(128, 128, 1), # 187 + 32 = 219
120
+ )
121
+ self.stage3 = nn.Sequential(
122
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
123
+ conv_dw(256, 256, 1), # 241 + 64 = 301
124
+ )
125
+
126
+ def forward(self, x):
127
+ x = self.stage1(x)
128
+ x = self.stage2(x)
129
+ x = self.stage3(x)
130
+ return x
131
+
132
+
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/retinaface.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models._utils as _utils
4
+ import torch.nn.functional as F
5
+ from .net import MobileNetV1 as MobileNetV1
6
+ from .net import FPN as FPN
7
+ from .net import SSH as SSH
8
+
9
+ from timm.models import mlp_mixer
10
+
11
+ class ClassHead(nn.Module):
12
+ def __init__(self,inchannels=512,num_anchors=3):
13
+ super(ClassHead,self).__init__()
14
+ self.num_anchors = num_anchors
15
+ self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
16
+
17
+ def forward(self,x):
18
+ out = self.conv1x1(x)
19
+ out = out.permute(0,2,3,1).contiguous()
20
+
21
+ return out.view(out.shape[0], -1, 2)
22
+
23
+ class BboxHead(nn.Module):
24
+ def __init__(self,inchannels=512,num_anchors=3):
25
+ super(BboxHead,self).__init__()
26
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
27
+
28
+ def forward(self,x):
29
+ out = self.conv1x1(x)
30
+ out = out.permute(0,2,3,1).contiguous()
31
+
32
+ return out.view(out.shape[0], -1, 4)
33
+
34
+ class LandmarkHead(nn.Module):
35
+ def __init__(self,inchannels=512,num_anchors=3):
36
+ super(LandmarkHead,self).__init__()
37
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
38
+
39
+ def forward(self,x):
40
+ out = self.conv1x1(x)
41
+ out = out.permute(0,2,3,1).contiguous()
42
+
43
+ return out.view(out.shape[0], -1, 10)
44
+
45
+ class RetinaFace(nn.Module):
46
+ def __init__(self, cfg = None, phase = 'train'):
47
+ """
48
+ :param cfg: Network related settings.
49
+ :param phase: train or test.
50
+ """
51
+ super(RetinaFace,self).__init__()
52
+ self.phase = phase
53
+ backbone = None
54
+ if cfg['name'] == 'mobilenet0.25':
55
+ backbone = MobileNetV1()
56
+ # if cfg['pretrain']:
57
+ # checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
58
+ # from collections import OrderedDict
59
+ # new_state_dict = OrderedDict()
60
+ # for k, v in checkpoint['state_dict'].items():
61
+ # name = k[7:] # remove module.
62
+ # new_state_dict[name] = v
63
+ # load params
64
+ # backbone.load_state_dict(new_state_dict)
65
+ elif cfg['name'] == 'Resnet50':
66
+ import torchvision.models as models
67
+ backbone = models.resnet50(pretrained=cfg['pretrain'])
68
+
69
+ self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
70
+ in_channels_stage2 = cfg['in_channel']
71
+ in_channels_list = [
72
+ in_channels_stage2 * 2,
73
+ in_channels_stage2 * 4,
74
+ in_channels_stage2 * 8,
75
+ ]
76
+ out_channels = cfg['out_channel']
77
+ self.fpn = FPN(in_channels_list,out_channels)
78
+ self.ssh1 = SSH(out_channels, out_channels)
79
+ self.ssh2 = SSH(out_channels, out_channels)
80
+ self.ssh3 = SSH(out_channels, out_channels)
81
+
82
+ self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
83
+ self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
84
+ self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
85
+
86
+ def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
87
+ classhead = nn.ModuleList()
88
+ for i in range(fpn_num):
89
+ classhead.append(ClassHead(inchannels,anchor_num))
90
+ return classhead
91
+
92
+ def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
93
+ bboxhead = nn.ModuleList()
94
+ for i in range(fpn_num):
95
+ bboxhead.append(BboxHead(inchannels,anchor_num))
96
+ return bboxhead
97
+
98
+ def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
99
+ landmarkhead = nn.ModuleList()
100
+ for i in range(fpn_num):
101
+ landmarkhead.append(LandmarkHead(inchannels,anchor_num))
102
+ return landmarkhead
103
+
104
+ def forward(self, inputs, priorbox=None):
105
+ out = self.body(inputs)
106
+
107
+ # FPN
108
+ fpn = self.fpn(out)
109
+
110
+ # SSH
111
+ feature1 = self.ssh1(fpn[0])
112
+ feature2 = self.ssh2(fpn[1])
113
+ feature3 = self.ssh3(fpn[2])
114
+ features = [feature1, feature2, feature3]
115
+
116
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
117
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
118
+ ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
119
+ if self.phase == 'train':
120
+ output = (bbox_regressions, classifications, ldm_regressions)
121
+ else:
122
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
123
+ return output
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ class Preprocessor():
5
+
6
+ def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
7
+ self.output_size = output_size
8
+ self.padding = padding
9
+ self.padding_val = padding_val
10
+
11
+ def preprocess_batched(self, imgs, padding_ratio_override=None):
12
+
13
+ # check img is of float
14
+ if imgs.dtype == torch.float32:
15
+ if self.padding_val == 'zero':
16
+ padding_val = -1.0
17
+ elif self.padding_val == 'mean':
18
+ padding_val = imgs.mean()
19
+ else:
20
+ raise ValueError('padding_val must be "zero" or "mean"')
21
+ elif imgs.dtype == torch.uint8:
22
+ if self.padding_val == 'zero':
23
+ padding_val = 0
24
+ elif self.padding_val == 'mean':
25
+ padding_val = imgs.mean()
26
+ else:
27
+ raise ValueError('padding_val must be "zero" or "mean"')
28
+ else:
29
+ raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
30
+
31
+ square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)
32
+
33
+ if padding_ratio_override is not None:
34
+ padding = padding_ratio_override
35
+ else:
36
+ padding = self.padding
37
+ padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)
38
+
39
+ size=(self.output_size, self.output_size)
40
+ if imgs.dtype == torch.float32:
41
+ resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
42
+ elif imgs.dtype == torch.uint8:
43
+ padded_imgs = padded_imgs.to(torch.float32)
44
+ resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
45
+ resized_imgs = torch.clip(resized_imgs, 0, 255)
46
+ resized_imgs = resized_imgs.to(torch.uint8)
47
+ else:
48
+ raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
49
+ return resized_imgs
50
+
51
+
52
+ def make_square_img_batched(self, imgs, padding_val):
53
+ assert imgs.ndim == 4
54
+ # squarify the image
55
+ h, w = imgs.shape[2:]
56
+ if h > w:
57
+ diff = (h - w)
58
+ pad_left = diff // 2
59
+ pad_right = diff - pad_left
60
+ imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
61
+ elif w > h:
62
+ diff = (w - h)
63
+ pad_top = diff // 2
64
+ pad_bottom = diff - pad_top
65
+ imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
66
+ assert imgs.shape[2] == imgs.shape[3]
67
+ return imgs
68
+
69
+
70
+ def make_padded_img_batched(self, imgs, padding, padding_val):
71
+ if padding == 0:
72
+ return imgs
73
+ assert imgs.ndim == 4
74
+
75
+
76
+ # pad the image
77
+ h, w = imgs.shape[2:]
78
+ pad_h = int(h * padding)
79
+ pad_w = int(w * padding)
80
+ imgs = F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)
81
+ return imgs
82
+
83
+
84
+ def __call__(self, input, padding_ratio_override=None):
85
+ if input.ndim == 3:
86
+ assert input.shape[0] == 3
87
+ batch_input = input.unsqueeze(0)
88
+ return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0]
89
+ elif input.ndim == 4:
90
+ assert input.shape[1] == 3
91
+ return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)
92
+ else:
93
+ raise ValueError(f'Invalid input shape: {input.shape}')
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/box_utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def point_form(boxes):
6
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7
+ representation for comparison to point form ground truth data.
8
+ Args:
9
+ boxes: (tensor) center-size default boxes from priorbox layers.
10
+ Return:
11
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12
+ """
13
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15
+
16
+
17
+ def center_size(boxes):
18
+ """ Convert prior_boxes to (cx, cy, w, h)
19
+ representation for comparison to center-size form ground truth data.
20
+ Args:
21
+ boxes: (tensor) point_form boxes
22
+ Return:
23
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
24
+ """
25
+ return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
26
+ boxes[:, 2:] - boxes[:, :2], 1) # w, h
27
+
28
+
29
+ def intersect(box_a, box_b):
30
+ """ We resize both tensors to [A,B,2] without new malloc:
31
+ [A,2] -> [A,1,2] -> [A,B,2]
32
+ [B,2] -> [1,B,2] -> [A,B,2]
33
+ Then we compute the area of intersect between box_a and box_b.
34
+ Args:
35
+ box_a: (tensor) bounding boxes, Shape: [A,4].
36
+ box_b: (tensor) bounding boxes, Shape: [B,4].
37
+ Return:
38
+ (tensor) intersection area, Shape: [A,B].
39
+ """
40
+ A = box_a.size(0)
41
+ B = box_b.size(0)
42
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46
+ inter = torch.clamp((max_xy - min_xy), min=0)
47
+ return inter[:, :, 0] * inter[:, :, 1]
48
+
49
+
50
+ def jaccard(box_a, box_b):
51
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52
+ is simply the intersection over union of two boxes. Here we operate on
53
+ ground truth boxes and default boxes.
54
+ E.g.:
55
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56
+ Args:
57
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59
+ Return:
60
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61
+ """
62
+ inter = intersect(box_a, box_b)
63
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
64
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
66
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67
+ union = area_a + area_b - inter
68
+ return inter / union # [A,B]
69
+
70
+
71
+ def matrix_iou(a, b):
72
+ """
73
+ return iou of a and b, numpy version for data augenmentation
74
+ """
75
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77
+
78
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82
+
83
+
84
+ def matrix_iof(a, b):
85
+ """
86
+ return iof of a and b, numpy version for data augenmentation
87
+ """
88
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90
+
91
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
94
+
95
+
96
+ def match(threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx):
97
+ """Match each prior box with the ground truth box of the highest jaccard
98
+ overlap, encode the bounding boxes, then return the matched indices
99
+ corresponding to both confidence and location preds.
100
+ Args:
101
+ threshold: (float) The overlap threshold used when mathing boxes.
102
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
103
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104
+ variances: (tensor) Variances corresponding to each prior coord,
105
+ Shape: [num_priors, 4].
106
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
107
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
108
+ loc_t: (tensor) Tensor to be filled w/ endcoded location targets.
109
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
110
+ landm_t: (tensor) Tensor to be filled w/ endcoded landm targets.
111
+ idx: (int) current batch index
112
+ Return:
113
+ The matched indices corresponding to 1)location 2)confidence 3)landm preds.
114
+ """
115
+
116
+
117
+ # jaccard index
118
+ overlaps = jaccard(
119
+ truths,
120
+ point_form(priorbox.priors)
121
+ )
122
+ # (Bipartite Matching)
123
+ # [1,num_objects] best prior for each ground truth
124
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
125
+
126
+ # ignore hard gt
127
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
128
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
129
+ if best_prior_idx_filter.shape[0] <= 0:
130
+ loc_t[idx] = 0
131
+ conf_t[idx] = 0
132
+ return
133
+
134
+ # [1,num_priors] best ground truth for each prior
135
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
136
+ best_truth_idx.squeeze_(0)
137
+ best_truth_overlap.squeeze_(0)
138
+ best_prior_idx.squeeze_(1)
139
+ best_prior_idx_filter.squeeze_(1)
140
+ best_prior_overlap.squeeze_(1)
141
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
142
+ # TODO refactor: index best_prior_idx with long tensor
143
+ # ensure every gt matches with its prior of max overlap
144
+ for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
145
+ best_truth_idx[best_prior_idx[j]] = j
146
+ matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
147
+ conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
148
+ conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
149
+ loc = priorbox.encode(matches)
150
+
151
+ matches_landm = landms[best_truth_idx]
152
+ landm = priorbox.encode_landm(matches_landm)
153
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
154
+ conf_t[idx] = conf # [num_priors] top class label for each prior
155
+ landm_t[idx] = landm
156
+
157
+
158
+
159
+ def log_sum_exp(x):
160
+ """Utility function for computing log_sum_exp while determining
161
+ This will be used to determine unaveraged confidence loss across
162
+ all examples in a batch.
163
+ Args:
164
+ x (Variable(tensor)): conf_preds from conf layers
165
+ """
166
+ x_max = x.data.max()
167
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
168
+
169
+
170
+ # Original author: Francisco Massa:
171
+ # https://github.com/fmassa/object-detection.torch
172
+ # Ported to PyTorch by Max deGroot (02/01/2017)
173
+ def nms(boxes, scores, overlap=0.5, top_k=200):
174
+ """Apply non-maximum suppression at test time to avoid detecting too many
175
+ overlapping bounding boxes for a given object.
176
+ Args:
177
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
178
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
179
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
180
+ top_k: (int) The Maximum number of box preds to consider.
181
+ Return:
182
+ The indices of the kept boxes with respect to num_priors.
183
+ """
184
+
185
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
186
+ if boxes.numel() == 0:
187
+ return keep
188
+ x1 = boxes[:, 0]
189
+ y1 = boxes[:, 1]
190
+ x2 = boxes[:, 2]
191
+ y2 = boxes[:, 3]
192
+ area = torch.mul(x2 - x1, y2 - y1)
193
+ v, idx = scores.sort(0) # sort in ascending order
194
+ # I = I[v >= 0.01]
195
+ idx = idx[-top_k:] # indices of the top-k largest vals
196
+ xx1 = boxes.new()
197
+ yy1 = boxes.new()
198
+ xx2 = boxes.new()
199
+ yy2 = boxes.new()
200
+ w = boxes.new()
201
+ h = boxes.new()
202
+
203
+ # keep = torch.Tensor()
204
+ count = 0
205
+ while idx.numel() > 0:
206
+ i = idx[-1] # index of current largest val
207
+ # keep.append(i)
208
+ keep[count] = i
209
+ count += 1
210
+ if idx.size(0) == 1:
211
+ break
212
+ idx = idx[:-1] # remove kept element from view
213
+ # load bboxes of next highest vals
214
+ torch.index_select(x1, 0, idx, out=xx1)
215
+ torch.index_select(y1, 0, idx, out=yy1)
216
+ torch.index_select(x2, 0, idx, out=xx2)
217
+ torch.index_select(y2, 0, idx, out=yy2)
218
+ # store element-wise max with next highest score
219
+ xx1 = torch.clamp(xx1, min=x1[i])
220
+ yy1 = torch.clamp(yy1, min=y1[i])
221
+ xx2 = torch.clamp(xx2, max=x2[i])
222
+ yy2 = torch.clamp(yy2, max=y2[i])
223
+ w.resize_as_(xx2)
224
+ h.resize_as_(yy2)
225
+ w = xx2 - xx1
226
+ h = yy2 - yy1
227
+ # check sizes of xx1 and xx2.. after each iteration
228
+ w = torch.clamp(w, min=0.0)
229
+ h = torch.clamp(h, min=0.0)
230
+ inter = w*h
231
+ # IoU = i / (area(a) + area(b) - i)
232
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
233
+ union = (rem_areas - inter) + area[i]
234
+ IoU = inter/union # store result in iou
235
+ # keep only elements with an IoU <= overlap
236
+ idx = idx[IoU.le(overlap)]
237
+ return keep, count
238
+
239
+
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/model_utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def remove_prefix(state_dict, prefix):
4
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
5
+ print('remove prefix \'{}\''.format(prefix))
6
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
7
+ return {f(key): value for key, value in state_dict.items()}
8
+
9
+ def check_keys(model, pretrained_state_dict):
10
+ ckpt_keys = set(pretrained_state_dict.keys())
11
+ model_keys = set(model.state_dict().keys())
12
+ used_pretrained_keys = model_keys & ckpt_keys
13
+ unused_pretrained_keys = ckpt_keys - model_keys
14
+ missing_keys = model_keys - ckpt_keys
15
+ print('Missing keys:{}'.format(len(missing_keys)))
16
+ print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
17
+ print('Used keys:{}'.format(len(used_pretrained_keys)))
18
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
19
+ return True
20
+
21
+ def load_model(model, pretrained_path, load_to_cpu):
22
+ print('Loading pretrained model from {}'.format(pretrained_path))
23
+ if load_to_cpu:
24
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
25
+ else:
26
+ device = torch.cuda.current_device()
27
+ pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
28
+ if "state_dict" in pretrained_dict.keys():
29
+ pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
30
+ else:
31
+ pretrained_dict = remove_prefix(pretrained_dict, 'module.')
32
+ check_keys(model, pretrained_dict)
33
+ model.load_state_dict(pretrained_dict, strict=False)
34
+ return model
35
+
36
+
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface_pipeline.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from .retinaface.utils.model_utils import load_model
5
+ from .retinaface.layers.functions.prior_box import PriorBox
6
+ from .retinaface.models.retinaface import RetinaFace
7
+ import torch.nn.functional as F
8
+
9
+
10
+ cfg_mnet = {
11
+ 'name': 'mobilenet0.25',
12
+ 'gpu_train': True,
13
+ 'ngpu': 1,
14
+ 'epoch': 250,
15
+ 'decay1': 190,
16
+ 'decay2': 220,
17
+ # 'image_size': 640,
18
+ 'pretrain': True,
19
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
20
+ 'in_channel': 32,
21
+ 'out_channel': 64
22
+ }
23
+
24
+
25
+ cfg_re50 = {
26
+ 'name': 'Resnet50',
27
+ 'gpu_train': True,
28
+ 'ngpu': 4,
29
+ 'epoch': 100,
30
+ 'decay1': 70,
31
+ 'decay2': 90,
32
+ # 'image_size': 840,
33
+ 'pretrain': True,
34
+ 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
35
+ 'in_channel': 256,
36
+ 'out_channel': 256
37
+ }
38
+
39
+
40
+ def load_retinface_model(network='resnet50', trained_model_path=''):
41
+ cfg = None
42
+ if network == "mobile0.25":
43
+ cfg = cfg_mnet
44
+ elif network == "resnet50":
45
+ cfg = cfg_re50
46
+ # net and model
47
+ net = RetinaFace(cfg=cfg, phase = 'test')
48
+ net = load_model(net, trained_model_path, True)
49
+ net.eval()
50
+ # freeze grad
51
+ for param in net.parameters():
52
+ param.requires_grad = False
53
+
54
+ return net
55
+
56
+
57
+ class RetinaFacePipeline(torch.nn.Module):
58
+
59
+ def __init__(self, net, priorbox, input_size, device='cuda'):
60
+ super().__init__()
61
+ self.net = net
62
+ self.priorbox = priorbox
63
+ self.input_size = input_size
64
+ self.output_size = 112
65
+ self.device = device
66
+
67
+
68
+ def normalize(self, image):
69
+ image = image / 255.
70
+ image = (image - 0.5) / 0.5
71
+ return image
72
+
73
+ def unnormalize(self, image):
74
+ image = image * 0.5 + 0.5
75
+ image = image * 255.
76
+ return image
77
+
78
+ def normalize_for_net(self, bgr_image_0_255):
79
+ # bgr_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
80
+ return bgr_image_0_255 - torch.tensor([104, 117, 123])[None, :, None, None].to(self.device)
81
+
82
+ def prealign_preprocess(self, images, value=0.0):
83
+ # pad to input_size
84
+ assert isinstance(images, torch.Tensor)
85
+ assert images.ndim == 4 or images.ndim == 3
86
+ input_size = self.input_size
87
+
88
+ data_width = images.shape[-1]
89
+ data_height = images.shape[-2]
90
+ if data_width > input_size or data_height > input_size:
91
+ # image is biggert than the input size
92
+ # resize such that the larger side becomes the input_size without changing the aspect ratio
93
+ if data_width > data_height:
94
+ scale = input_size / data_width
95
+ else:
96
+ scale = input_size / data_height
97
+ if images.ndim == 4:
98
+ images = F.interpolate(input=images, scale_factor=scale,
99
+ mode='bilinear', align_corners=False)
100
+ else:
101
+ images = F.interpolate(input=images.unsqueeze(0), scale_factor=scale,
102
+ mode='bilinear', align_corners=False).squeeze(0)
103
+
104
+ data_width = images.shape[-1]
105
+ data_height = images.shape[-2]
106
+ padding_width1 = (input_size - data_width) // 2
107
+ padding_width2 = (input_size - data_width) - padding_width1
108
+ padding_height1 = (input_size - data_height) // 2
109
+ padding_height2 = (input_size - data_height) - padding_height1
110
+
111
+ result = torch.nn.functional.pad(input=images,
112
+ pad=(padding_width1, padding_width2,
113
+ padding_height1, padding_height2),
114
+ value=value)
115
+ assert result.shape[-1] == input_size
116
+ assert result.shape[-2] == input_size
117
+ return result
118
+
119
+ def forward(self, rgb_images):
120
+
121
+ # cv2.imwrite('/mckim/temp/temp.jpg', self.unnormalize(rgb_images[0]).cpu().numpy().transpose(1,2,0))
122
+
123
+ assert rgb_images.shape[1] == 3
124
+ assert rgb_images.ndim == 4
125
+ assert isinstance(rgb_images, torch.Tensor)
126
+ assert self.priorbox.image_size == rgb_images.shape[2:]
127
+ rgb_images = rgb_images.to(self.device)
128
+
129
+ # make image into BGR
130
+ bgr_images = rgb_images.flip(1)
131
+ input_img = self.normalize_for_net(self.unnormalize(bgr_images))
132
+ batch_loc, batch_conf, batch_landms = self.net(input_img)
133
+ batch_loc = torch.split(batch_loc, 1, dim=0)
134
+ batch_conf = torch.split(batch_conf, 1, dim=0)
135
+ batch_landms = torch.split(batch_landms, 1, dim=0)
136
+
137
+ all_ldmks = []
138
+ for loc, conf, landms, in zip(batch_loc, batch_conf, batch_landms):
139
+ dets = postprocess(self.priorbox, loc, conf, landms, confidence_threshold=0.0, nms_threshold=0.4)
140
+ bbox, score, ldmks = parse_one_det_result(dets)
141
+ ldmks = ldmks / np.array( [self.priorbox.image_size[0], self.priorbox.image_size[1]] * 5)
142
+ all_ldmks.append(ldmks)
143
+ all_ldmks = torch.from_numpy(np.array(all_ldmks)).to(self.device).float()
144
+ return all_ldmks
145
+
146
+
147
+ def postprocess(priorbox, loc, conf, landms, confidence_threshold, nms_threshold):
148
+
149
+ device = loc.device
150
+ im_height, im_width = priorbox.image_size
151
+
152
+ scale = torch.Tensor([im_width, im_height, im_width, im_height])
153
+ scale = scale.to(device)
154
+
155
+ boxes = priorbox.decode(loc.data.squeeze(0))
156
+ boxes = boxes * scale
157
+ boxes = boxes.cpu().numpy()
158
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
159
+ landms = priorbox.decode_landm(landms.data.squeeze(0))
160
+ scale1 = torch.Tensor([im_width, im_height, im_width, im_height,
161
+ im_width, im_height, im_width, im_height,
162
+ im_width, im_height])
163
+ scale1 = scale1.to(device)
164
+ landms = landms * scale1
165
+ landms = landms.cpu().numpy()
166
+
167
+ # ignore low scores
168
+ inds = np.where(scores > confidence_threshold)[0]
169
+ if len(inds) == 0:
170
+ inds = np.where(scores >= 0)[0]
171
+ boxes = boxes[inds]
172
+ landms = landms[inds]
173
+ scores = scores[inds]
174
+
175
+ # keep top-K before NMS
176
+ order = scores.argsort()[::-1]
177
+ # order = scores.argsort()[::-1][:args.top_k]
178
+ boxes = boxes[order]
179
+ landms = landms[order]
180
+ scores = scores[order]
181
+
182
+ # do NMS
183
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
184
+ keep = py_cpu_nms(dets, nms_threshold)
185
+ # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
186
+ dets = dets[keep, :]
187
+ landms = landms[keep]
188
+
189
+ # keep top-K faster NMS
190
+ # dets = dets[:args.keep_top_k, :]
191
+ # landms = landms[:args.keep_top_k, :]
192
+
193
+ dets = np.concatenate((dets, landms), axis=1)
194
+ return dets
195
+
196
+
197
+ def py_cpu_nms(dets, thresh):
198
+ """Pure Python NMS baseline."""
199
+ x1 = dets[:, 0]
200
+ y1 = dets[:, 1]
201
+ x2 = dets[:, 2]
202
+ y2 = dets[:, 3]
203
+ scores = dets[:, 4]
204
+
205
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
206
+ order = scores.argsort()[::-1]
207
+
208
+ keep = []
209
+ while order.size > 0:
210
+ i = order[0]
211
+ keep.append(i)
212
+ xx1 = np.maximum(x1[i], x1[order[1:]])
213
+ yy1 = np.maximum(y1[i], y1[order[1:]])
214
+ xx2 = np.minimum(x2[i], x2[order[1:]])
215
+ yy2 = np.minimum(y2[i], y2[order[1:]])
216
+
217
+ w = np.maximum(0.0, xx2 - xx1 + 1)
218
+ h = np.maximum(0.0, yy2 - yy1 + 1)
219
+ inter = w * h
220
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
221
+
222
+ inds = np.where(ovr <= thresh)[0]
223
+ order = order[inds + 1]
224
+
225
+ return keep
226
+
227
+
228
+ def parse_one_det_result(dets):
229
+ dets_sorted = dets[dets[:, 4].argsort()[::-1]]
230
+ result = dets_sorted[0]
231
+ bbox = result[:4]
232
+ score = result[4]
233
+ ldmks = result[5:]
234
+ return bbox, score, ldmks
235
+
236
+
237
+ def load_retinaface_pipeline(network, trained_model_path, input_size, device):
238
+ net = load_retinface_model(network='resnet50', trained_model_path=trained_model_path)
239
+ net = net.to(device)
240
+ priorbox = PriorBox(image_size=(input_size, input_size),
241
+ min_sizes=[[16, 32], [64, 128], [256, 512]],
242
+ steps=[8,16,32], clip=False,
243
+ variances=[0.1, 0.2],
244
+ device=device)
245
+ pipeline = RetinaFacePipeline(net, priorbox, input_size, device=device)
246
+ pipeline.cuda()
247
+ return pipeline
cvlface/research/recognition/code/run_v1/base.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - trainers : configs/default
3
+ - optims : configs/cosine
4
+ - pefts: configs/none
5
+ - models: vit/configs/v1_small
6
+ - classifiers: configs/partial_fc
7
+ - aligners: configs/none
8
+ - dataset: configs/casia
9
+ - data_augs: configs/v7
10
+ - losses: configs/adaface
11
+ - pipelines: configs/train_model_cls
12
+ - evaluations: configs/base
cvlface/research/recognition/code/run_v1/classifiers/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import partial_fc
2
+ from . import fc
3
+
4
+ def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
5
+
6
+ if margin_loss_fn is None:
7
+ classifier = None
8
+ print("No margin loss function provided, classifier will not be created")
9
+ return classifier
10
+
11
+ if classifier_cfg.name == 'partial_fc':
12
+ classifier = partial_fc.PartialFCClassifier.from_config(classifier_cfg, margin_loss_fn,
13
+ model_cfg, num_classes,
14
+ rank, world_size)
15
+ elif classifier_cfg.name == 'fc':
16
+ classifier = fc.FCClassifier.from_config(classifier_cfg, margin_loss_fn,
17
+ model_cfg, num_classes,
18
+ rank, world_size)
19
+
20
+ else:
21
+ raise ValueError(f"Unknown classifier: {classifier_cfg.name}")
22
+
23
+ if classifier_cfg.start_from:
24
+ classifier.load_state_dict_from_path(classifier_cfg.start_from)
25
+
26
+ if classifier_cfg.freeze:
27
+ for param in classifier.parameters():
28
+ param.requires_grad = False
29
+
30
+ return classifier
31
+
cvlface/research/recognition/code/run_v1/classifiers/base/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+ import torch
4
+ from torch import device
5
+ from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
6
+ from general_utils.os_utils import natural_sort
7
+
8
+ class BaseClassifier(torch.nn.Module):
9
+
10
+ def __init__(self, config=None):
11
+ super(BaseClassifier, self).__init__()
12
+ self.config = config
13
+
14
+ @classmethod
15
+ def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, dataset_cfg, rank, world_size) -> "BaseClassifier":
16
+ raise NotImplementedError('from_config must be implemented in subclass')
17
+
18
+ def forward(self, local_embeddings, local_labels):
19
+ raise NotImplementedError('from_config must be implemented in subclass')
20
+
21
+
22
+ @property
23
+ def device(self) -> device:
24
+ return get_parameter_device(self)
25
+
26
+ @property
27
+ def dtype(self) -> torch.dtype:
28
+ return get_parameter_dtype(self)
29
+
30
+ def num_parameters(self, only_trainable: bool = False) -> int:
31
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
32
+
33
+ def has_trainable_params(self):
34
+ for param in self.parameters():
35
+ if param.requires_grad:
36
+ return True
37
+ return False
38
+
39
+ def save_pretrained(
40
+ self,
41
+ save_dir: Union[str, os.PathLike],
42
+ name: str = 'model.pt',
43
+ rank: int = 0,
44
+ ):
45
+ rank_added_name = os.path.splitext(name)[0] + f'_rank{rank}' + os.path.splitext(name)[1]
46
+ save_path = os.path.join(save_dir, rank_added_name)
47
+ save_state_dict_and_config(self.state_dict(), self.config, save_path)
48
+
49
+
50
+ def load_state_dict_from_path(self, pretrained_model_path):
51
+
52
+ save_dir = os.path.dirname(pretrained_model_path)
53
+ save_name = os.path.basename(pretrained_model_path)
54
+ rank_added_name = os.path.splitext(save_name)[0] + f'_rank{self.rank}' + os.path.splitext(save_name)[1]
55
+ pretrained_model_path = os.path.join(save_dir, rank_added_name)
56
+
57
+ all_partitions = [name for name in os.listdir(save_dir) if '_rank' in name and '.pt' in name]
58
+ all_partitions = natural_sort(all_partitions)
59
+ ckpt_worldsize = len(all_partitions)
60
+
61
+ if self.world_size != ckpt_worldsize:
62
+ # we need to redistribute the partialfc weights
63
+ part_ckpts = [torch.load(os.path.join(save_dir, name), map_location='cpu') for name in all_partitions]
64
+ total_ckpt_num_subjects = sum([ckpt['partial_fc.weight'].shape[0] for ckpt in part_ckpts])
65
+ assert total_ckpt_num_subjects - self.partial_fc.num_classes < 10, \
66
+ (f"total_ckpt_num_subjects: {total_ckpt_num_subjects}, "
67
+ f"self.partial_fc.num_classes: {self.partial_fc.num_classes}"
68
+ f"The number can be slightly different due to the last partition.")
69
+
70
+ combined_weight = torch.cat([ckpt['partial_fc.weight'] for ckpt in part_ckpts], dim=0)
71
+ state_dict = part_ckpts[0]
72
+
73
+ class_start = self.partial_fc.class_start
74
+ num_sample = self.partial_fc.num_local
75
+ sub_center = combined_weight[class_start:class_start + num_sample, :]
76
+ if sub_center.shape[0] != num_sample:
77
+ # append zero
78
+ extra_center = torch.zeros(num_sample - sub_center.shape[0], sub_center.shape[1],
79
+ device=self.device, dtype=self.dtype)
80
+ sub_center = torch.cat([sub_center, extra_center], dim=0)
81
+ state_dict['partial_fc.weight'] = sub_center
82
+
83
+ else:
84
+ state_dict = load_state_dict_from_path(pretrained_model_path)
85
+
86
+ result = self.load_state_dict(state_dict, strict=False)
87
+ print(result)
cvlface/research/recognition/code/run_v1/classifiers/base/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import List, Optional, Tuple, Union
3
+ import safetensors
4
+ import torch
5
+ from torch import Tensor
6
+ import os
7
+ from pathlib import Path
8
+ from omegaconf import DictConfig, OmegaConf
9
+
10
+
11
+ def get_parameter_device(parameter: torch.nn.Module):
12
+ try:
13
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
14
+ return next(parameters_and_buffers).device
15
+ except StopIteration:
16
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
17
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
18
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
19
+ return tuples
20
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
21
+ first_tuple = next(gen)
22
+ return first_tuple[1].device
23
+
24
+
25
+ def get_parameter_dtype(parameter: torch.nn.Module):
26
+ try:
27
+ params = tuple(parameter.parameters())
28
+ if len(params) > 0:
29
+ return params[0].dtype
30
+
31
+ buffers = tuple(parameter.buffers())
32
+ if len(buffers) > 0:
33
+ return buffers[0].dtype
34
+
35
+ except StopIteration:
36
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
37
+
38
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
39
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
40
+ return tuples
41
+
42
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
43
+ first_tuple = next(gen)
44
+ return first_tuple[1].dtype
45
+
46
+
47
+ def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
48
+ path_obj = Path(save_path)
49
+ return path_obj.parent
50
+
51
+ def get_base_name(save_path: Union[str, os.PathLike]) -> str:
52
+ path_obj = Path(save_path)
53
+ return path_obj.name
54
+
55
+ def load_state_dict_from_path(path: Union[str, os.PathLike]):
56
+ # Load a state dict from a path.
57
+ if 'safetensors' in path:
58
+ state_dict = safetensors.torch.load_file(path)
59
+ else:
60
+ state_dict = torch.load(path, map_location="cpu")
61
+ return state_dict
62
+
63
+ def replace_extension(path, new_extension):
64
+ if not new_extension.startswith('.'):
65
+ new_extension = '.' + new_extension
66
+ return os.path.splitext(path)[0] + new_extension
67
+
68
+ def make_config_path(save_path):
69
+ config_path = replace_extension(save_path, '.yaml')
70
+ return config_path
71
+
72
+ def save_config(config, config_path):
73
+ assert isinstance(config, dict) or isinstance(config, DictConfig)
74
+ os.makedirs(get_parent_directory(config_path), exist_ok=True)
75
+ if isinstance(config, dict):
76
+ config = OmegaConf.create(config)
77
+ OmegaConf.save(config, config_path)
78
+
79
+
80
+ def save_state_dict_and_config(state_dict, config, save_path):
81
+ os.makedirs(get_parent_directory(save_path), exist_ok=True)
82
+
83
+ # save config dict
84
+ config_path = make_config_path(save_path)
85
+ save_config(config, config_path)
86
+
87
+ # Save the model
88
+ if 'safetensors' in save_path:
89
+ safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
90
+ else:
91
+ torch.save(state_dict, save_path)
cvlface/research/recognition/code/run_v1/classifiers/configs/fc.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: 'fc'
2
+ sample_rate: 1.0
3
+ start_from: ''
4
+ freeze: False
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: 'partial_fc'
2
+ sample_rate: 1.0
3
+ start_from: ''
4
+ freeze: False
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_freeze.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: 'partial_fc'
2
+ sample_rate: 1.0
3
+ start_from: ''
4
+ freeze: True
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: 'partial_fc'
2
+ sample_rate: 0.1
3
+ start_from: ''
4
+ freeze: False
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10_freeze.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: 'partial_fc'
2
+ sample_rate: 0.1
3
+ start_from: ''
4
+ freeze: True
cvlface/research/recognition/code/run_v1/classifiers/fc/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseClassifier, load_state_dict_from_path
2
+ from .fc import FC
3
+ from typing import Union
4
+ import os
5
+
6
+
7
+ class FCClassifier(BaseClassifier):
8
+
9
+ def __init__(self, classifier, config, rank, world_size):
10
+ super(FCClassifier, self).__init__()
11
+ self.classifier = classifier
12
+ self.config = config
13
+ self.rank = rank
14
+ self.world_size = world_size
15
+ self.apply_ddp = True
16
+
17
+ @classmethod
18
+ def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
19
+ if classifier_cfg.name == 'fc':
20
+ classifier = FC(
21
+ margin_loss=margin_loss_fn,
22
+ embedding_size=model_cfg.output_dim,
23
+ num_classes=num_classes,
24
+ )
25
+ else:
26
+ raise NotImplementedError
27
+
28
+ model = cls(classifier, classifier_cfg, rank, world_size)
29
+ model.eval()
30
+ return model
31
+
32
+ def forward(self, local_embeddings, local_labels):
33
+ loss = self.classifier(local_embeddings, local_labels)
34
+ return loss
35
+
36
+ def save_pretrained(
37
+ self,
38
+ save_dir: Union[str, os.PathLike],
39
+ name: str = 'classifier.pt',
40
+ rank: int = 0,
41
+ ):
42
+ if rank == 0:
43
+ super().save_pretrained(save_dir, name, rank)
44
+
45
+ def load_state_dict_from_path(self, pretrained_model_path):
46
+ save_dir = os.path.dirname(pretrained_model_path)
47
+ save_name = os.path.basename(pretrained_model_path)
48
+ rank_added_name = os.path.splitext(save_name)[0] + f'_rank0' + os.path.splitext(save_name)[1]
49
+ pretrained_model_path = os.path.join(save_dir, rank_added_name)
50
+
51
+ state_dict = load_state_dict_from_path(pretrained_model_path)
52
+ result = self.load_state_dict(state_dict, strict=False)
53
+ print('classifier loading result', result)
54
+
55
+
cvlface/research/recognition/code/run_v1/classifiers/fc/fc.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import torch
3
+ from torch import distributed
4
+ from torch.nn.functional import linear, normalize
5
+ from losses.margin_loss import CombinedMarginLoss
6
+ from losses.adaface import AdaFaceLoss
7
+
8
+
9
+
10
+ class FC(torch.nn.Module):
11
+
12
+ def __init__(
13
+ self,
14
+ margin_loss: Callable,
15
+ embedding_size: int,
16
+ num_classes: int,
17
+ ):
18
+ super(FC, self).__init__()
19
+
20
+ self.cross_entropy = torch.nn.CrossEntropyLoss()
21
+ self.embedding_size = embedding_size
22
+ self.num_classes = num_classes
23
+ self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_classes, embedding_size)))
24
+
25
+ # margin_loss
26
+ if isinstance(margin_loss, Callable):
27
+ self.margin_softmax = margin_loss
28
+ if isinstance(margin_loss, AdaFaceLoss):
29
+ self.register_buffer('batch_mean', torch.ones(1)*(20))
30
+ self.register_buffer('batch_std', torch.ones(1)*100)
31
+ else:
32
+ raise
33
+
34
+
35
+ def forward(
36
+ self,
37
+ local_embeddings: torch.Tensor,
38
+ local_labels: torch.Tensor,
39
+ ):
40
+
41
+ embeddings = local_embeddings
42
+ labels = local_labels
43
+ weight = self.weight
44
+
45
+ norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8)
46
+ norm_embeddings = embeddings / norms
47
+
48
+ norm_weight_activated = normalize(weight)
49
+ logits = linear(norm_embeddings, norm_weight_activated)
50
+ logits = logits.clamp(-1, 1)
51
+
52
+ if isinstance(self.margin_softmax, CombinedMarginLoss):
53
+ logits = self.margin_softmax(logits=logits, labels=labels)
54
+ elif isinstance(self.margin_softmax, AdaFaceLoss):
55
+ logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms,
56
+ batch_mean=self.batch_mean,
57
+ batch_std=self.batch_std)
58
+ self.batch_mean.data = batch_mean.data
59
+ self.batch_std.data = batch_std.data
60
+ else:
61
+ raise ValueError('parital FC margin_softmax not supported type')
62
+
63
+ loss = self.cross_entropy(logits, labels)
64
+ return loss
65
+
66
+
67
+
cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseClassifier
2
+ from .partial_fc import PartialFC_V2
3
+
4
+
5
+ class PartialFCClassifier(BaseClassifier):
6
+
7
+ def __init__(self, classifier, config, rank, world_size):
8
+ super(PartialFCClassifier, self).__init__()
9
+ self.partial_fc = classifier
10
+ self.config = config
11
+ self.rank = rank
12
+ self.world_size = world_size
13
+ self.apply_ddp = False
14
+
15
+ @classmethod
16
+ def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
17
+ if classifier_cfg.name == 'partial_fc':
18
+ classifier = PartialFC_V2(
19
+ rank=rank,
20
+ world_size=world_size,
21
+ margin_loss=margin_loss_fn,
22
+ embedding_size=model_cfg.output_dim,
23
+ num_classes=num_classes,
24
+ sample_rate=classifier_cfg.sample_rate,
25
+ )
26
+ else:
27
+ raise NotImplementedError
28
+
29
+ model = cls(classifier, classifier_cfg, rank, world_size)
30
+ model.eval()
31
+ return model
32
+
33
+ def forward(self, local_embeddings, local_labels):
34
+ loss = self.partial_fc(local_embeddings, local_labels)
35
+ return loss
36
+
37
+
38
+
39
+
cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import torch
3
+ from torch import distributed
4
+ from torch.nn.functional import linear, normalize
5
+ from losses.margin_loss import CombinedMarginLoss
6
+ from losses.adaface import AdaFaceLoss
7
+
8
+
9
+
10
+ class PartialFC_V2(torch.nn.Module):
11
+ """
12
+ https://arxiv.org/abs/2203.15565
13
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
14
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of
15
+ negative class centers are selected to compute the margin-based softmax loss, all class
16
+ centers are still maintained throughout the whole training process, but only a subset is
17
+ selected and updated in each iteration.
18
+ .. note::
19
+ When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
20
+ Example:
21
+ --------
22
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
23
+ >>> for img, labels in data_loader:
24
+ >>> embeddings = net(img)
25
+ >>> loss = module_pfc(embeddings, labels)
26
+ >>> loss.backward()
27
+ >>> optimizer.step()
28
+ """
29
+ _version = 2
30
+
31
+ def __init__(
32
+ self,
33
+ rank: int,
34
+ world_size: int,
35
+ margin_loss: Callable,
36
+ embedding_size: int,
37
+ num_classes: int,
38
+ sample_rate: float = 1.0,
39
+ ):
40
+ """
41
+ Paramenters:
42
+ -----------
43
+ embedding_size: int
44
+ The dimension of embedding, required
45
+ num_classes: int
46
+ Total number of classes, required
47
+ sample_rate: float
48
+ The rate of negative centers participating in the calculation, default is 1.0.
49
+ """
50
+ super(PartialFC_V2, self).__init__()
51
+ assert (
52
+ distributed.is_initialized()
53
+ ), "must initialize distributed before create this"
54
+ self.rank = rank
55
+ self.world_size = world_size
56
+
57
+ self.dist_cross_entropy = DistCrossEntropy()
58
+ self.embedding_size = embedding_size
59
+ self.sample_rate: float = sample_rate
60
+
61
+ # make num_class divisible by self.world_size for ddp
62
+ _num_classes = num_classes // self.world_size * self.world_size
63
+ if _num_classes < num_classes:
64
+ _num_classes = _num_classes + self.world_size
65
+ num_classes = _num_classes
66
+ self.num_classes: int = num_classes
67
+
68
+ self.num_local: int = num_classes // self.world_size + int(
69
+ self.rank < num_classes % self.world_size
70
+ )
71
+
72
+ # for i in range(8):
73
+ # num_local = (num_classes // self.world_size + int( i < num_classes % self.world_size ))
74
+ # class_start = num_classes // self.world_size * i + min( i, num_classes % self.world_size )
75
+ # print(num_local, class_start)
76
+
77
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
78
+ self.rank, num_classes % self.world_size
79
+ )
80
+ self.num_sample: int = int(self.sample_rate * self.num_local)
81
+ self.last_batch_size: int = 0
82
+
83
+ self.is_updated: bool = True
84
+ self.init_weight_update: bool = True
85
+ self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
86
+
87
+ # margin_loss
88
+ if isinstance(margin_loss, Callable):
89
+ self.margin_softmax = margin_loss
90
+ if isinstance(margin_loss, AdaFaceLoss):
91
+ self.register_buffer('batch_mean', torch.ones(1)*(20))
92
+ self.register_buffer('batch_std', torch.ones(1)*100)
93
+ else:
94
+ raise
95
+
96
+ def sample(self, labels, index_positive):
97
+ """
98
+ This functions will change the value of labels
99
+ Parameters:
100
+ -----------
101
+ labels: torch.Tensor
102
+ pass
103
+ index_positive: torch.Tensor
104
+ pass
105
+ optimizer: torch.optim.Optimizer
106
+ pass
107
+ """
108
+ with torch.no_grad():
109
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
110
+ if self.num_sample - positive.size(0) >= 0:
111
+ perm = torch.rand(size=[self.num_local]).cuda()
112
+ perm[positive] = 2.0
113
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
114
+ index = index.sort()[0].cuda()
115
+ else:
116
+ index = positive
117
+ self.weight_index = index
118
+
119
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
120
+
121
+ return self.weight[self.weight_index]
122
+
123
+ def forward(
124
+ self,
125
+ local_embeddings: torch.Tensor,
126
+ local_labels: torch.Tensor,
127
+ ):
128
+ """
129
+ Parameters:
130
+ ----------
131
+ local_embeddings: torch.Tensor
132
+ feature embeddings on each GPU(Rank).
133
+ local_labels: torch.Tensor
134
+ labels on each GPU(Rank).
135
+ Returns:
136
+ -------
137
+ loss: torch.Tensor
138
+ pass
139
+ """
140
+
141
+ local_labels.squeeze_()
142
+ local_labels = local_labels.long()
143
+
144
+ batch_size = local_embeddings.size(0)
145
+ if self.last_batch_size == 0:
146
+ self.last_batch_size = batch_size
147
+ assert self.last_batch_size == batch_size, (
148
+ f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")
149
+
150
+ _gather_embeddings = [
151
+ torch.zeros((batch_size, self.embedding_size), dtype=local_embeddings.dtype, device=local_embeddings.device)
152
+ for _ in range(self.world_size)
153
+ ]
154
+ _gather_labels = [
155
+ torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
156
+ ]
157
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
158
+ distributed.all_gather(_gather_labels, local_labels)
159
+
160
+ embeddings = torch.cat(_list_embeddings)
161
+ labels = torch.cat(_gather_labels)
162
+
163
+ labels = labels.view(-1, 1)
164
+ index_positive = (self.class_start <= labels) & (
165
+ labels < self.class_start + self.num_local
166
+ )
167
+ labels[~index_positive] = -1
168
+ labels[index_positive] -= self.class_start
169
+
170
+ if self.sample_rate < 1:
171
+ weight = self.sample(labels, index_positive)
172
+ else:
173
+ weight = self.weight
174
+
175
+ # with torch.cuda.amp.autocast(self.fp16):
176
+ norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8)
177
+ norm_embeddings = embeddings / norms
178
+
179
+ norm_weight_activated = normalize(weight)
180
+ logits = linear(norm_embeddings, norm_weight_activated)
181
+
182
+ logits = logits.clamp(-1, 1)
183
+
184
+ if isinstance(self.margin_softmax, CombinedMarginLoss):
185
+ logits = self.margin_softmax(logits=logits, labels=labels)
186
+ elif isinstance(self.margin_softmax, AdaFaceLoss):
187
+ logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms,
188
+ batch_mean=self.batch_mean,
189
+ batch_std=self.batch_std)
190
+ self.batch_mean.data = batch_mean.data
191
+ self.batch_std.data = batch_std.data
192
+ else:
193
+ raise ValueError('parital FC margin_softmax not supported type')
194
+
195
+ loss = self.dist_cross_entropy(logits, labels)
196
+ return loss
197
+
198
+
199
+ class DistCrossEntropyFunc(torch.autograd.Function):
200
+ """
201
+ CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
202
+ Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
203
+ """
204
+
205
+ @staticmethod
206
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
207
+ """ """
208
+ batch_size = logits.size(0)
209
+ # for numerical stability
210
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
211
+ # local to global
212
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
213
+ logits.sub_(max_logits)
214
+ logits.exp_()
215
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
216
+ # local to global
217
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
218
+ logits.div_(sum_logits_exp)
219
+ index = torch.where(label != -1)[0]
220
+ # loss
221
+ loss = torch.zeros(batch_size, 1, device=logits.device, dtype=logits.dtype)
222
+ loss[index] = logits[index].gather(1, label[index])
223
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
224
+ ctx.save_for_backward(index, logits, label)
225
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
226
+
227
+ @staticmethod
228
+ def backward(ctx, loss_gradient):
229
+ """
230
+ Args:
231
+ loss_grad (torch.Tensor): gradient backward by last layer
232
+ Returns:
233
+ gradients for each input in forward function
234
+ `None` gradients for one-hot label
235
+ """
236
+ (
237
+ index,
238
+ logits,
239
+ label,
240
+ ) = ctx.saved_tensors
241
+ batch_size = logits.size(0)
242
+ one_hot = torch.zeros(
243
+ size=[index.size(0), logits.size(1)], device=logits.device
244
+ )
245
+ one_hot.scatter_(1, label[index], 1)
246
+ logits[index] -= one_hot
247
+ logits.div_(batch_size)
248
+ return logits * loss_gradient.item(), None
249
+
250
+
251
+ class DistCrossEntropy(torch.nn.Module):
252
+ def __init__(self):
253
+ super(DistCrossEntropy, self).__init__()
254
+
255
+ def forward(self, logit_part, label_part):
256
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
257
+
258
+
259
+ class AllGatherFunc(torch.autograd.Function):
260
+ """AllGather op with gradient backward"""
261
+
262
+ @staticmethod
263
+ def forward(ctx, tensor, *gather_list):
264
+ gather_list = list(gather_list)
265
+ distributed.all_gather(gather_list, tensor)
266
+ return tuple(gather_list)
267
+
268
+ @staticmethod
269
+ def backward(ctx, *grads):
270
+ grad_list = list(grads)
271
+ rank = distributed.get_rank()
272
+ grad_out = grad_list[rank]
273
+
274
+ dist_ops = [
275
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
276
+ if i == rank
277
+ else distributed.reduce(
278
+ grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
279
+ )
280
+ for i in range(distributed.get_world_size())
281
+ ]
282
+ for _op in dist_ops:
283
+ _op.wait()
284
+
285
+ grad_out *= len(grad_list) # cooperate with distributed loss function
286
+ return (grad_out, *[None for _ in range(len(grad_list))])
287
+
288
+
289
+ AllGather = AllGatherFunc.apply