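"""Facial-landmark annotation with the STAR face alignment model.

Detects a face per frame with dlib, refines 68 two-dimensional landmarks with
the STAR alignment network, and saves the per-camera results as
"landmark2d/STAR" npz files.
"""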
from tqdm import tqdm
import copy
import argparse
import torch
import math
import cv2
import numpy as np
import dlib

from star.lib import utility
from star.asset import predictor_path, model_path

from vhap.util.log import get_logger

logger = get_logger(__name__)


class GetCropMatrix:
    """
    from_shape -> transform_matrix
    """

    def __init__(self, image_size, target_face_scale, align_corners=False):
        self.image_size = image_size
        self.target_face_scale = target_face_scale
        self.align_corners = align_corners

    def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
        # Build a 3x3 affine matrix that rotates by `angle` and scales by
        # `scale` about `from_center`, then translates to `to_center` + `shift_xy`.
        cosv = math.cos(angle)
        sinv = math.sin(angle)

        fx, fy = from_center
        tx, ty = to_center

        acos = scale * cosv
        asin = scale * sinv

        a0 = acos
        a1 = -asin
        a2 = tx - acos * fx + asin * fy + shift_xy[0]

        b0 = asin
        b1 = acos
        b2 = ty - asin * fx - acos * fy + shift_xy[1]

        rot_scale_m = np.array([
            [a0, a1, a2],
            [b0, b1, b2],
            [0.0, 0.0, 1.0]
        ], np.float32)
        return rot_scale_m

    def process(self, scale, center_w, center_h):
        if self.align_corners:
            to_w, to_h = self.image_size - 1, self.image_size - 1
        else:
            to_w, to_h = self.image_size, self.image_size

        rot_mu = 0
        scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
        shift_xy_mu = (0, 0)
        matrix = self._compose_rotate_and_scale(
            rot_mu, scale_mu, shift_xy_mu,
            from_center=[center_w, center_h],
            to_center=[to_w / 2.0, to_h / 2.0])
        return matrix


class TransformPerspective:
    """
    image, matrix3x3 -> transformed_image
    """

    def __init__(self, image_size):
        self.image_size = image_size

    def process(self, image, matrix):
        return cv2.warpPerspective(
            image, matrix, dsize=(self.image_size, self.image_size),
            flags=cv2.INTER_LINEAR, borderValue=0)


class TransformPoints2D:
    """
    points (nx2), matrix (3x3) -> points (nx2)
    """

    def process(self, srcPoints, matrix):
        # Lift to homogeneous coordinates, apply the matrix, then dehomogenize.
        desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
        desPoints = desPoints @ np.transpose(matrix)
        desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
        return desPoints.astype(srcPoints.dtype)
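

# Illustrative sketch of how the three helpers above chain together (not part
# of the original pipeline; `image`, `pts`, `scale`, `center_w`, `center_h`
# are placeholders):
#
#   crop = GetCropMatrix(image_size=256, target_face_scale=1.0, align_corners=True)
#   M = crop.process(scale, center_w, center_h)
#   patch = TransformPerspective(image_size=256).process(image, M)
#   pts_in_patch = TransformPoints2D().process(pts, M)
#   pts_back = TransformPoints2D().process(pts_in_patch, np.linalg.inv(M))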


class Alignment:
    def __init__(self, args, model_path, dl_framework, device_ids):
        self.input_size = 256
        self.target_face_scale = 1.0
        self.dl_framework = dl_framework

        if self.dl_framework == "pytorch":
            # Load the model configuration.
            self.config = utility.get_config(args)
            self.config.device_id = device_ids[0]
            # Set up the runtime environment and logging.
            utility.set_environment(self.config)
            self.config.init_instance()
            if self.config.logger is not None:
                self.config.logger.info("Loaded config file %s: %s" % (args.config_name, self.config.id))
                self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))

            net = utility.get_net(self.config)
            if device_ids == [-1]:
                checkpoint = torch.load(model_path, map_location="cpu")
            else:
                checkpoint = torch.load(model_path)
            net.load_state_dict(checkpoint["net"])
            net = net.to(self.config.device_id)
            net.eval()
            self.alignment = net
        else:
            raise NotImplementedError(f"Unsupported dl_framework: {self.dl_framework}")

        self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
                                           align_corners=True)
        self.transformPerspective = TransformPerspective(image_size=self.input_size)
        self.transformPoints2D = TransformPoints2D()

    def norm_points(self, points, align_corners=False):
        if align_corners:
            # [0, SIZE-1] -> [-1, +1]
            return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
        else:
            # [-0.5, SIZE-0.5] -> [-1, +1]
            return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1

    def denorm_points(self, points, align_corners=False):
        if align_corners:
            # [-1, +1] -> [0, SIZE-1]
            return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
        else:
            # [-1, +1] -> [-0.5, SIZE-0.5]
            return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2

    def preprocess(self, image, scale, center_w, center_h):
        matrix = self.getCropMatrix.process(scale, center_w, center_h)
        input_tensor = self.transformPerspective.process(image, matrix)
        input_tensor = input_tensor[np.newaxis, :]

        input_tensor = torch.from_numpy(input_tensor)
        input_tensor = input_tensor.float().permute(0, 3, 1, 2)
        input_tensor = input_tensor / 255.0 * 2.0 - 1.0  # scale pixel values to [-1, 1]
        input_tensor = input_tensor.to(self.config.device_id)
        return input_tensor, matrix

    def postprocess(self, srcPoints, coeff):
        # Apply the affine part of `coeff` (here: the inverse crop matrix)
        # to map landmarks from the crop back to the original image.
        dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
        for i in range(srcPoints.shape[0]):
            dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
            dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
        return dstPoints
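
    # The loop in `postprocess` is equivalent to a single vectorized affine
    # application; a minimal sketch, assuming `srcPoints` is a float (n, 2) array:
    #   dstPoints = srcPoints @ coeff[:2, :2].T + coeff[:2, 2]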

    def analyze(self, image, scale, center_w, center_h):
        input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)

        if self.dl_framework == "pytorch":
            with torch.no_grad():
                output = self.alignment(input_tensor)
                landmarks = output[-1][0]
        else:
            raise NotImplementedError(f"Unsupported dl_framework: {self.dl_framework}")

        landmarks = self.denorm_points(landmarks)
        landmarks = landmarks.data.cpu().numpy()[0]
        landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))

        return landmarks


def draw_pts(img, pts, mode="pts", shift=4, color=(0, 255, 0), radius=1, thickness=1, save_path=None, dif=0,
             scale=0.3, concat=False):
    img_draw = copy.deepcopy(img)
    for cnt, p in enumerate(pts):
        if mode == "index":
            cv2.putText(img_draw, str(cnt), (int(float(p[0] + dif)), int(float(p[1] + dif))), cv2.FONT_HERSHEY_SIMPLEX,
                        scale, color, thickness)
        elif mode == 'pts':
            if len(img_draw.shape) > 2:
                # BGR -> RGB -> BGR round trip; effectively a no-op, kept for
                # compatibility with the original implementation.
                img_draw = cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB)
                img_draw = cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR)
            cv2.circle(img_draw, (int(p[0] * (1 << shift)), int(p[1] * (1 << shift))), radius << shift, color, -1,
                       cv2.LINE_AA, shift=shift)
        else:
            raise NotImplementedError
    if concat:
        img_draw = np.concatenate((img, img_draw), axis=1)
    if save_path is not None:
        cv2.imwrite(save_path, img_draw)
    return img_draw
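
# Illustrative usage (hypothetical values): overlay landmarks, given in pixel
# coordinates, on a frame and write the visualization to disk.
#   vis = draw_pts(img, lmks[:, :2], mode="pts", save_path="lmk_vis.jpg")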


class LandmarkDetectorSTAR:
    def __init__(self):
        # dlib provides the face box and a coarse 68-point initialization.
        self.detector = dlib.get_frontal_face_detector()
        self.shape_predictor = dlib.shape_predictor(predictor_path)

        args = argparse.Namespace()
        args.config_name = 'alignment'

        # Device ids are parsed from a comma-separated string; a single GPU here.
        device_ids = '0'
        device_ids = list(map(int, device_ids.split(",")))
        self.alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)

    def detect_single_image(self, img):
        bbox = self.detector(img, 1)

        if len(bbox) == 0:
            # No face detected: return all -1 placeholders.
            bbox = np.zeros(5) - 1
            lmks = np.zeros([68, 3]) - 1
        else:
            # Use dlib's 68 points only to estimate the face scale and center,
            # then let STAR predict the final landmarks.
            face = self.shape_predictor(img, bbox[0])
            shape = []
            for i in range(68):
                x = face.part(i).x
                y = face.part(i).y
                shape.append((x, y))
            shape = np.array(shape)
            x1, x2 = shape[:, 0].min(), shape[:, 0].max()
            y1, y2 = shape[:, 1].min(), shape[:, 1].max()
            scale = min(x2 - x1, y2 - y1) / 200 * 1.05
            center_w = (x2 + x1) / 2
            center_h = (y2 + y1) / 2

            scale, center_w, center_h = float(scale), float(center_w), float(center_h)
            lmks = self.alignment.analyze(img, scale, center_w, center_h)

            h, w = img.shape[:2]

            # Append a confidence column and normalize to [0, 1] image coordinates.
            lmks = np.concatenate([lmks, np.ones([lmks.shape[0], 1])], axis=1).astype(np.float32)
            lmks[:, 0] /= w
            lmks[:, 1] /= h

            bbox = np.array([bbox[0].left(), bbox[0].top(), bbox[0].right(), bbox[0].bottom(), 1.]).astype(np.float32)
            bbox[[0, 2]] /= w
            bbox[[1, 3]] /= h

        return bbox, lmks
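
    # A minimal sketch of single-image use (hypothetical file path); bbox has
    # shape (5,) and lmks shape (68, 3), both normalized by width/height:
    #   img = cv2.imread("frame.jpg")
    #   bbox, lmks = LandmarkDetectorSTAR().detect_single_image(img)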

    def detect_dataset(self, dataloader):
        """
        Annotates each frame with 68 facial landmarks.
        :return: dict mapping camera id to a dict mapping timestep id to a
            landmark numpy array, and the same structure for bboxes
        """
        logger.info("Initialize Landmark Detector (STAR)...")

        landmarks = {}
        bboxes = {}

        logger.info("Begin annotating landmarks...")
        for item in tqdm(dataloader):
            timestep_id = item["timestep_id"][0]
            camera_id = item["camera_id"][0]

            logger.info(
                f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}"
            )
            img = item["rgb"][0].numpy()

            bbox, lmks = self.detect_single_image(img)
            if (bbox < 0).all():  # detect_single_image returns all -1 when no face is found
                logger.error(
                    f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1."
                )

            if camera_id not in landmarks:
                landmarks[camera_id] = {}
            if camera_id not in bboxes:
                bboxes[camera_id] = {}
            landmarks[camera_id][timestep_id] = lmks
            bboxes[camera_id][timestep_id] = bbox
        return landmarks, bboxes

    def annotate_landmarks(self, dataloader):
        """
        Annotates each frame with 68 facial landmarks. Assumes frames have been extracted.
        :return:
        """
        lmks_face, bboxes_faces = self.detect_dataset(dataloader)

        # Stack per-timestep annotations and save one npz file per camera.
        for camera_id, lmk_face_camera in lmks_face.items():
            bounding_box = []
            face_landmark_2d = []
            for timestep_id in lmk_face_camera.keys():
                bounding_box.append(bboxes_faces[camera_id][timestep_id][None])
                face_landmark_2d.append(lmks_face[camera_id][timestep_id][None])

            lmk_dict = {
                "bounding_box": bounding_box,
                "face_landmark_2d": face_landmark_2d,
            }

            for k, v in lmk_dict.items():
                if len(v) > 0:
                    lmk_dict[k] = np.concatenate(v, axis=0)
            out_path = dataloader.dataset.get_property_path(
                "landmark2d/STAR", camera_id=camera_id
            )
            logger.info(f"Saving landmarks to: {out_path}")
            if not out_path.parent.exists():
                out_path.parent.mkdir(parents=True)
            np.savez(out_path, **lmk_dict)
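
        # The saved file can be read back with np.load(out_path); it holds
        # "bounding_box" of shape (T, 5) and "face_landmark_2d" of shape
        # (T, 68, 3), where T is the number of annotated timesteps.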


if __name__ == "__main__":
    import tyro
    from torch.utils.data import DataLoader
    from vhap.config.base import DataConfig, import_module

    cfg = tyro.cli(DataConfig)
    dataset = import_module(cfg._target)(
        cfg=cfg,
        img_to_tensor=False,
        batchify_all_views=True,
    )
    dataset.items = dataset.items[:2]  # only annotate the first two items for a quick test

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    detector = LandmarkDetectorSTAR()
    detector.annotate_landmarks(dataloader)