minchul commited on
Commit
ce4d684
·
verified ·
1 Parent(s): d435aa8

Upload directory

Browse files
aligners/differentiable_face_aligner/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseAligner
2
+ from torchvision import transforms
3
+ from .dfa import get_landmark_predictor, get_preprocessor
4
+ from . import aligner_helper
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+
9
+
10
class DifferentiableFaceAligner(BaseAligner):

    '''
    A differentiable face aligner that aligns the image with one face to a canonical position.
    The aligner is based on the following paper (check out supplementary material for more details):
    @inproceedings{kim2024kprpe,
    title={KeyPoint Relative Position Encoding for Face Recognition},
    author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
    booktitle={CVPR},
    year={2024}
    }
    '''

    def __init__(self, net, prior_box, preprocessor, config):
        """Store the landmark network, its prior boxes, the input preprocessor,
        and the config object (read later for input_size / output_size).

        NOTE(review): `forward` reads `self.device`, which is never assigned
        here — presumably provided by `BaseAligner`; confirm.
        """
        super(DifferentiableFaceAligner, self).__init__()
        self.net = net                      # landmark predictor network (run on BGR input)
        self.prior_box = prior_box          # anchor/prior boxes passed to the network each forward
        self.preprocessor = preprocessor    # resizes/pads input to config.input_size
        self.config = config                # read for input_size, output_size

    @classmethod
    def from_config(cls, config):
        """Alternate constructor: build the network, prior boxes, and
        preprocessor from `config`, optionally freeze the network weights
        (config.freeze), and return the model in eval mode.
        """
        net, prior_box = get_landmark_predictor(network=config.arch,
                                                use_aggregator=True,
                                                input_size=config.input_size)

        preprocessor = get_preprocessor(output_size=config.input_size,
                                        padding=config.input_padding_ratio,
                                        padding_val=config.input_padding_val)
        if config.freeze:
            # Detach the landmark network from training entirely.
            for param in net.parameters():
                param.requires_grad = False
        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Detect 5 facial landmarks in `x` and warp the face to the canonical
        reference position.

        Parameters
        ----------
        x : torch.Tensor
            Batch of RGB images, shape (B, 3, H, W). Values are assumed to be
            in the preprocessor's expected range (presumably [-1, 1] given the
            Normalize(mean=0.5, std=0.5) transforms below — confirm).
        padding_ratio_override : float or None
            If given, overrides the preprocessor's configured padding ratio.

        Returns
        -------
        tuple of (aligned_x, ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox):
            aligned_x       : aligned face crops, (B, 3, output_size, output_size)
            ldmk_pred       : landmarks in the ORIGINAL image coordinate frame
                              (un-padded when padding > 0); None if input was
                              not square (preprocessor changes non-square input,
                              so original-frame coordinates are invalid)
            aligned_ldmks   : landmarks in the aligned-crop frame, (B, 5, 2)
            score           : face confidence from softmax over class logits
            thetas          : affine warp parameters used by grid_sample
            normalized_bbox : bbox (xmin, ymin, xmax, ymax) normalized by the
                              preprocessed image width/height; None if input
                              was not square
        """

        # input size check
        assert x.shape[1] == 3
        assert x.ndim == 4
        assert isinstance(x, torch.Tensor)
        # Non-square inputs are resized/padded by the preprocessor, so any
        # coordinates expressed in the original frame become invalid (see
        # the `is_square` guards before returning).
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        # Prior boxes were generated for a fixed image size; it must match.
        assert self.prior_box.image_size == x.shape[2:]

        # make image into BGR (channel flip) — the landmark network expects BGR
        x_bgr = x.flip(1)
        result = self.net(x_bgr, self.prior_box)
        orig_pred_ldmks, bbox, cls = aligner_helper.split_network_output(result)
        # Face confidence: softmax over class logits, keep the non-background column(s).
        score = torch.nn.Softmax(dim=-1)(cls)[:,1:]

        # Compute the affine warp that moves the predicted landmarks onto the
        # canonical reference landmarks, then convert the cv2-style affine
        # params into a torch `affine_grid` theta.
        reference_ldmk = aligner_helper.reference_landmark()
        input_size = self.config.input_size
        output_size = self.config.output_size
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
        thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
        thetas = thetas.to(orig_pred_ldmks.device)

        # Differentiable warp of the (non-flipped) preprocessed image.
        output_size = torch.Size((len(thetas), 3, output_size, output_size))
        grid = F.affine_grid(thetas, output_size, align_corners=True)
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1 # +1, -1 for making padding pixel 0
        # Landmarks mapped into the aligned-crop coordinate frame.
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        # bbox (xmin, ymin, xmax, ymax) — normalize by (W, H, W, H) of the preprocessed image
        normalized_bbox = bbox / torch.tensor([[x_bgr.size(3), x_bgr.size(2)] * 2]).to(bbox.device)


        # Resolve the effective padding ratio (override wins over the
        # preprocessor's configured value).
        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override
        if padding_ratio > 0:
            # unpad the landmark so that it is in the original image coordinate.
            # The padded image is (1 + 2*padding) times the original on each
            # side; invert that scaling about the image center by mapping
            # coords to [-1, 1], applying the inverse-scale affine, and
            # mapping back to [0, 1].
            scale = 1 / (1 + (2 * padding_ratio))
            pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
            # NOTE(review): `self.device` is not set in this class — assumed
            # to come from BaseAligner; confirm.
            pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
            # Homogeneous coordinates: append a column of ones for the affine matmul.
            unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
                                            torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
            unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
            unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach()
            unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2)
            if not is_square:
                unpad_ldmk_pred = None # cannot use this if the input is not square because preprocessor changes input
                normalized_bbox = None # cannot use this if the input is not square because preprocessor changes input
            return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox

        if not is_square:
            orig_pred_ldmks = None # cannot use this if the input is not square because preprocessor changes input
            normalized_bbox = None # cannot use this if the input is not square because preprocessor changes input
        return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox

    def make_train_transform(self):
        """Return the training-time input transform: PIL/ndarray -> tensor,
        then normalize each channel to roughly [-1, 1]."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform

    def make_test_transform(self):
        """Return the test-time input transform (identical to the training
        transform: ToTensor followed by per-channel 0.5/0.5 normalization)."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform
117
+