import os
import cv2
import math
import argparse
import numpy as np
from tqdm import tqdm

import torch

from lib import utility


class GetCropMatrix:
    """
    from_shape -> transform_matrix
    """

    def __init__(self, image_size, target_face_scale, align_corners=False):
        self.image_size = image_size
        self.target_face_scale = target_face_scale
        self.align_corners = align_corners
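
    # Build a 3x3 homogeneous matrix that scales and rotates points about
    # from_center, then translates them to to_center (plus shift_xy):
    #   p' = scale * R(angle) * (p - from_center) + to_center + shift_xy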
    def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
        cosv = math.cos(angle)
        sinv = math.sin(angle)

        fx, fy = from_center
        tx, ty = to_center

        acos = scale * cosv
        asin = scale * sinv

        a0 = acos
        a1 = -asin
        a2 = tx - acos * fx + asin * fy + shift_xy[0]

        b0 = asin
        b1 = acos
        b2 = ty - asin * fx - acos * fy + shift_xy[1]

        rot_scale_m = np.array([
            [a0, a1, a2],
            [b0, b1, b2],
            [0.0, 0.0, 1.0]
        ], np.float32)
        return rot_scale_m

    def process(self, scale, center_w, center_h):
        if self.align_corners:
            to_w, to_h = self.image_size - 1, self.image_size - 1
        else:
            to_w, to_h = self.image_size, self.image_size

        rot_mu = 0
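        # Map a source-image region of size scale * target_face_scale * 200 px
        # onto the full crop; the 200 px factor follows the reference-box
        # convention used by the dataset metadata's scale values.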
        scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
        shift_xy_mu = (0, 0)
        matrix = self._compose_rotate_and_scale(
            rot_mu, scale_mu, shift_xy_mu,
            from_center=[center_w, center_h],
            to_center=[to_w / 2.0, to_h / 2.0])
        return matrix


class TransformPerspective:
    """
    image, matrix3x3 -> transformed_image
    """

    def __init__(self, image_size):
        self.image_size = image_size

    def process(self, image, matrix):
        return cv2.warpPerspective(
            image, matrix, dsize=(self.image_size, self.image_size),
            flags=cv2.INTER_LINEAR, borderValue=0)


class TransformPoints2D:
    """
    points (nx2), matrix (3x3) -> points (nx2)
    """

    def process(self, srcPoints, matrix):
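        # Append a homogeneous coordinate so the 3x3 matrix can translate the
        # points, then divide by the third component to drop back to 2D.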
        desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
        desPoints = desPoints @ np.transpose(matrix)
        desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
        return desPoints.astype(srcPoints.dtype)
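

# End-to-end inference wrapper: crops the face from the full image, runs the
# alignment network, and maps the predicted landmarks back to original image
# coordinates.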
class Alignment:
    def __init__(self, args, model_path, dl_framework, device_ids):
        self.input_size = 256
        self.target_face_scale = 1.0
        self.dl_framework = dl_framework

        if self.dl_framework == "pytorch":
            self.config = utility.get_config(args)
            self.config.device_id = device_ids[0]

            utility.set_environment(self.config)
            self.config.init_instance()
            if self.config.logger is not None:
                self.config.logger.info("Loaded configuration file %s: %s" % (args.config_name, self.config.id))
                self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))

            net = utility.get_net(self.config)
            if device_ids == [-1]:
                checkpoint = torch.load(model_path, map_location="cpu")
            else:
                checkpoint = torch.load(model_path)
            net.load_state_dict(checkpoint["net"])
            net = net.to(self.config.device_id)
            net.eval()
            self.alignment = net
        else:
            raise ValueError("unsupported dl_framework: %s" % self.dl_framework)

        self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
                                           align_corners=True)
        self.transformPerspective = TransformPerspective(image_size=self.input_size)
        self.transformPoints2D = TransformPoints2D()
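
    # Convert between pixel coordinates in [0, input_size) and normalized
    # coordinates in [-1, 1], following torch grid_sample semantics:
    # align_corners=True treats the corner pixels as the extremes, while
    # align_corners=False normalizes pixel centers ((p + 0.5) / size * 2 - 1).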
    def norm_points(self, points, align_corners=False):
        if align_corners:
            return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
        else:
            return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1

    def denorm_points(self, points, align_corners=False):
        if align_corners:
            return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
        else:
            return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2

    def preprocess(self, image, scale, center_w, center_h):
        matrix = self.getCropMatrix.process(scale, center_w, center_h)
        input_tensor = self.transformPerspective.process(image, matrix)
        input_tensor = input_tensor[np.newaxis, :]
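
        # HWC uint8 image (BGR, as read by cv2) -> NCHW float tensor in [-1, 1].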
        input_tensor = torch.from_numpy(input_tensor)
        input_tensor = input_tensor.float().permute(0, 3, 1, 2)
        input_tensor = input_tensor / 255.0 * 2.0 - 1.0
        input_tensor = input_tensor.to(self.config.device_id)
        return input_tensor, matrix
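
    # Apply the affine part of a 3x3 transform to each (x, y) point;
    # analyze() passes the inverse crop matrix here to map predictions
    # back to original image coordinates.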
    def postprocess(self, srcPoints, coeff):
        dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
        for i in range(srcPoints.shape[0]):
            dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
            dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
        return dstPoints

    def analyze(self, image, scale, center_w, center_h):
        input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)

        if self.dl_framework == "pytorch":
            with torch.no_grad():
                output = self.alignment(input_tensor)
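            # Keep the final-stage landmark predictions (normalized crop coordinates).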
            landmarks = output[-1][0]
        else:
            raise ValueError("unsupported dl_framework: %s" % self.dl_framework)

        landmarks = self.denorm_points(landmarks)
        landmarks = landmarks.data.cpu().numpy()[0]
        landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))

        return landmarks


def L2(p1, p2):
    return np.linalg.norm(p1 - p2)
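

# Normalized Mean Error: mean point-to-point distance divided by the
# inter-ocular span, with eye-corner indices chosen per landmark layout
# (29-point COFW, 68-point 300W, 98-point WFLW).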
def NME(landmarks_gt, landmarks_pv):
    pts_num = landmarks_gt.shape[0]
    if pts_num == 29:
        left_index = 16
        right_index = 17
    elif pts_num == 68:
        left_index = 36
        right_index = 45
    elif pts_num == 98:
        left_index = 60
        right_index = 72
    else:
        raise ValueError("unsupported number of landmarks: %d" % pts_num)

    nme = 0
    eye_span = L2(landmarks_gt[left_index], landmarks_gt[right_index])
    for i in range(pts_num):
        error = L2(landmarks_pv[i], landmarks_gt[i])
        nme += error / eye_span
    nme /= pts_num
    return nme


def evaluate(args, model_path, metadata_path, device_ids, mode):
    alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
    config = alignment.config
    nme_sum = 0
    with open(metadata_path, 'r') as f:
        lines = f.readlines()
    for k, line in enumerate(tqdm(lines)):
        item = line.strip().split("\t")
        image_name, landmarks_5pts, landmarks_gt, scale, center_w, center_h = item[:6]
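
        # Strip legacy path prefixes so image_name is relative to image_dir.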
        image_name = image_name.replace('\\', '/')
        image_name = image_name.replace('//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
        image_name = image_name.replace('./rawImages/', '')
        image_path = os.path.join(config.image_dir, image_name)
        landmarks_gt = np.array(list(map(float, landmarks_gt.split(","))), dtype=np.float32).reshape(-1, 2)
        scale, center_w, center_h = float(scale), float(center_w), float(center_h)

        image = cv2.imread(image_path)
        landmarks_pv = alignment.analyze(image, scale, center_w, center_h)

        if mode == "nme":
            nme = NME(landmarks_gt, landmarks_pv)
            nme_sum += nme
        else:
            pass

    if mode == "nme":
        print("Final NME: %f" % (100 * nme_sum / (k + 1)))
    else:
        pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation script")
    parser.add_argument("--config_name", type=str, default="alignment", help="name of the configuration file")
    parser.add_argument("--model_path", type=str, default="./train.pkl", help="path of the model checkpoint")
    parser.add_argument("--data_definition", type=str, default='WFLW', help="COFW/300W/WFLW")
    parser.add_argument("--metadata_path", type=str, default="", help="path of the metadata file")
    parser.add_argument("--image_dir", type=str, default="", help="path of the image directory")
    parser.add_argument("--device_ids", type=str, default="0", help="comma-separated device ids; -1 means CPU, >= 0 means GPU")
    parser.add_argument("--mode", type=str, default="nme", help="evaluation mode: nme")
    args = parser.parse_args()
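
    # Parse comma-separated device ids, e.g. "0,1" -> [0, 1]; "-1" requests CPU.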
    device_ids = list(map(int, args.device_ids.split(",")))
    evaluate(
        args,
        model_path=args.model_path,
        metadata_path=args.metadata_path,
        device_ids=device_ids,
        mode=args.mode)