Upload directory
Browse files
aligners/differentiable_face_aligner/__init__.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseAligner
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from .dfa import get_landmark_predictor, get_preprocessor
|
| 4 |
+
from . import aligner_helper
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DifferentiableFaceAligner(BaseAligner):

    '''
    A differentiable face aligner that aligns the image with one face to a canonical position.
    The aligner is based on the following paper (check out supplementary material for more details):
    @inproceedings{kim2024kprpe,
    title={KeyPoint Relative Position Encoding for Face Recognition},
    author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
    booktitle={CVPR},
    year={2024}
    }
    '''

    def __init__(self, net, prior_box, preprocessor, config):
        """
        Args:
            net: landmark-prediction network (see `get_landmark_predictor`).
            prior_box: prior/anchor boxes matching the network's input size.
            preprocessor: resize/pad module produced by `get_preprocessor`.
            config: aligner config; this class reads `input_size`, `output_size`,
                `input_padding_ratio`, `input_padding_val`, `arch` and `freeze`.
        """
        super().__init__()
        self.net = net
        self.prior_box = prior_box
        self.preprocessor = preprocessor
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Alternate constructor: build network, prior box and preprocessor from `config`.

        Freezes the network parameters when `config.freeze` is set and returns the
        model in eval mode.
        """
        net, prior_box = get_landmark_predictor(network=config.arch,
                                                use_aggregator=True,
                                                input_size=config.input_size)

        preprocessor = get_preprocessor(output_size=config.input_size,
                                        padding=config.input_padding_ratio,
                                        padding_val=config.input_padding_val)
        if config.freeze:
            for param in net.parameters():
                param.requires_grad = False
        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Align the face in each image of `x` to the canonical landmark position.

        Args:
            x: float tensor of shape (B, 3, H, W) in RGB channel order.
               Presumably normalized to [-1, 1] (the train/test transforms use
               mean/std 0.5) — the +1/-1 trick in grid_sample below relies on
               padded pixels mapping to 0. TODO confirm against callers.
            padding_ratio_override: optional padding ratio to use instead of the
                preprocessor's configured padding.

        Returns:
            (aligned_x, landmarks, aligned_ldmks, score, thetas, normalized_bbox).
            `landmarks` and `normalized_bbox` are None when the input is not
            square, because the preprocessor changes non-square inputs.
        """
        # Input sanity checks — rank before channel count so a low-rank tensor
        # fails the assert instead of raising an IndexError.
        assert isinstance(x, torch.Tensor)
        assert x.ndim == 4
        assert x.shape[1] == 3
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        assert self.prior_box.image_size == x.shape[2:]

        # The landmark network expects BGR input; flip the channel axis.
        x_bgr = x.flip(1)
        result = self.net(x_bgr, self.prior_box)
        orig_pred_ldmks, bbox, cls_logits = aligner_helper.split_network_output(result)
        # Face score: softmax over classes, keep columns from 1 on
        # (presumably column 0 is background — verify against split_network_output).
        score = F.softmax(cls_logits, dim=-1)[:,1:]

        reference_ldmk = aligner_helper.reference_landmark()
        input_size = self.config.input_size
        output_size = self.config.output_size
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
        thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
        thetas = thetas.to(orig_pred_ldmks.device)

        # Warp each image to the canonical position.
        grid_shape = torch.Size((len(thetas), 3, output_size, output_size))
        grid = F.affine_grid(thetas, grid_shape, align_corners=True)
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1  # +1, -1 for making padding pixel 0
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        # bbox (xmin, ymin, xmax, ymax) -> normalized by image width/height.
        normalized_bbox = bbox / torch.tensor([[x_bgr.size(3), x_bgr.size(2)] * 2]).to(bbox.device)

        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override
        if padding_ratio > 0:
            # Unpad the landmark so that it is in the original image coordinate.
            # NOTE(review): self.device is not set in __init__ — presumably
            # provided by BaseAligner; confirm.
            scale = 1 / (1 + (2 * padding_ratio))
            pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
            pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
            # Homogeneous coordinates: append a ones column before the affine product.
            unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
                                            torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
            # Map [0, 1] coords to [-1, 1], invert the padding, and map back.
            unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
            unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach()
            unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2)
            if not is_square:
                unpad_ldmk_pred = None  # cannot use this if the input is not square because preprocessor changes input
                normalized_bbox = None  # cannot use this if the input is not square because preprocessor changes input
            return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox

        if not is_square:
            orig_pred_ldmks = None  # cannot use this if the input is not square because preprocessor changes input
            normalized_bbox = None  # cannot use this if the input is not square because preprocessor changes input
        return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox

    def _make_transform(self):
        # Train and test use the identical pipeline: ToTensor then [-1, 1] normalization.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_train_transform(self):
        """Transform mapping a PIL/ndarray image to a normalized tensor in [-1, 1]."""
        return self._make_transform()

    def make_test_transform(self):
        """Transform mapping a PIL/ndarray image to a normalized tensor in [-1, 1]."""
        return self._make_transform()
| 117 |
+
|