| | import torch |
| | import numpy as np |
| | from einops import rearrange |
| | from kornia.geometry.transform.crop2d import warp_affine |
| |
|
| | from utils.matlab_cp2tform import get_similarity_transform_for_cv2 |
| | from torchvision.transforms import Pad |
| |
|
| | REFERNCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051], |
| | [72.53179932, 51.50139999], |
| | [56.02519989, 71.73660278], |
| | [41.54930115, 92.3655014], |
| | [70.72990036, 92.20410156] |
| | ]) / 112 |
| |
|
| |
|
| | @torch.no_grad() |
| | def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> torch.Tensor: |
| | """ |
| | Detect faces in the images using MTCNN. If no face is detected, use the whole image. |
| | """ |
| | images = rearrange(images, "b c h w -> b h w c") |
| | if images.dtype != torch.uint8: |
| | images = ((images * 0.5 + 0.5) * 255).type(torch.uint8) |
| | |
| | _, _, landmarks = mtcnn(images, landmarks=True) |
| |
|
| | return landmarks |
| |
|
| |
|
| | def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, refernce_points=REFERNCE_FACIAL_POINTS_RELATIVE): |
| | """ |
| | detect faces in the images and crop them (in a differentiable way) to 112x112 using MTCNN. |
| | """ |
| | images = Pad(200)(images) |
| | landmarks_batched = detect_face(images, mtcnn=mtcnn) |
| | affine_transformations = [] |
| | invalid_indices = [] |
| | for i, landmarks in enumerate(landmarks_batched): |
| | if landmarks is None: |
| | invalid_indices.append(i) |
| | affine_transformations.append(np.eye(2, 3).astype(np.float32)) |
| | else: |
| | affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32), |
| | refernce_points.astype(np.float32) * output_size)) |
| | affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32) |
| |
|
| | invalid_indices = torch.tensor(invalid_indices).to(device=images.device) |
| |
|
| | fp_images = images.to(torch.float32) |
| | return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices |