Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import numpy as np | |
| import torch | |
| from .utils import convert_to_numpy | |
class FaceAnnotator:
    """Detect faces with InsightFace and return raw detections or face crops.

    Keys read from ``cfg``:
        RETURN_RAW (default True): return the detector output unchanged.
        RETURN_MASK (default False): also return one bounding-box mask per face.
        RETURN_DICT (default False): when masks are returned, use a dict
            instead of a tuple.
        MULTI_FACE (default True): keep every detected face; otherwise keep
            only the first detection.
        PRETRAINED_MODEL (required): passed to ``FaceAnalysis`` as ``root``.
        MODEL_NAME (required): passed to ``FaceAnalysis`` as ``name``.
    """

    def __init__(self, cfg, device=None):
        """Build and prepare the InsightFace ``FaceAnalysis`` model.

        Args:
            cfg: Mapping-like config supporting ``cfg[key]`` and ``cfg.get``.
            device: ``torch.device``, device string (e.g. ``"cuda:0"``) or
                ``None`` to pick CUDA when available and CPU otherwise.
        """
        # Local import: insightface is only needed once an annotator is built.
        from insightface.app import FaceAnalysis

        self.return_raw = cfg.get('RETURN_RAW', True)
        self.return_mask = cfg.get('RETURN_MASK', False)
        self.return_dict = cfg.get('RETURN_DICT', False)
        self.multi_face = cfg.get('MULTI_FACE', True)
        pretrained_model = cfg['PRETRAINED_MODEL']

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # torch.device() accepts strings as well as existing torch.device
        # objects, so callers may pass either form.
        self.device = torch.device(device)
        self.device_id = self.device.index if self.device.type == 'cuda' else None

        if self.device.type == 'cuda':
            # A bare "cuda" device carries no index; fall back to GPU 0.
            ctx_id = self.device_id if self.device_id is not None else 0
        else:
            # InsightFace treats a negative ctx_id as "run on CPU".
            ctx_id = -1

        self.model = FaceAnalysis(
            name=cfg['MODEL_NAME'],
            root=pretrained_model,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        self.model.prepare(ctx_id=ctx_id, det_size=(640, 640))

    @staticmethod
    def _bbox_to_slices(bbox, height, width):
        """Turn an inclusive ``(x_min, y_min, x_max, y_max)`` box into slices.

        The box is clipped to ``[0, height] x [0, width]`` so that coordinates
        lying outside the image (detectors can report negative values) never
        wrap around as negative NumPy indices.

        Returns:
            ``(row_slice, col_slice)`` usable as ``image[row_slice, col_slice]``.
        """
        x_min, y_min, x_max, y_max = (int(v) for v in bbox)
        top = min(max(y_min, 0), height)
        left = min(max(x_min, 0), width)
        # +1 because the box's max corner is treated as inclusive.
        bottom = min(max(y_max + 1, top), height)
        right = min(max(x_max + 1, left), width)
        return slice(top, bottom), slice(left, right)

    def forward(self, image=None, return_mask=None, return_dict=None):
        """Run face detection on ``image``.

        Args:
            image: Input handed to ``convert_to_numpy``; the result is indexed
                as ``image[:, :, 0]``, so it is assumed to be an H x W x C
                array (TODO confirm against ``convert_to_numpy``).
            return_mask: Overrides ``self.return_mask`` when not ``None``.
            return_dict: Overrides ``self.return_dict`` when not ``None``.

        Returns:
            * the raw detector output when ``self.return_raw`` is true;
            * ``None`` when no face is detected;
            * otherwise the face crop(s), optionally with the mask(s), either
              as ``(crops, masks)`` or as ``{'image': crops, 'mask': masks}``.
              With ``multi_face`` disabled a single crop / mask is returned
              instead of lists.
        """
        return_mask = self.return_mask if return_mask is None else return_mask
        return_dict = self.return_dict if return_dict is None else return_dict
        image = convert_to_numpy(image)
        # Each detection exposes keys such as 'bbox', 'kps', 'det_score',
        # 'landmark_3d_68', 'pose', 'landmark_2d_106', 'gender', 'age',
        # 'embedding'.
        faces = self.model.get(image)
        if self.return_raw:
            return faces
        if len(faces) == 0:
            return None

        if not self.multi_face:
            faces = faces[:1]

        height, width = image.shape[:2]
        crop_face_list, mask_list = [], []
        for face in faces:
            rows, cols = self._bbox_to_slices(face['bbox'].tolist(), height, width)
            crop_face_list.append(image[rows, cols])
            # Single-channel mask with the image's dtype; 255 inside the box.
            mask = np.zeros_like(image[:, :, 0])
            mask[rows, cols] = 255
            mask_list.append(mask)

        if not self.multi_face:
            crop_face_list = crop_face_list[0]
            mask_list = mask_list[0]

        if not return_mask:
            return crop_face_list
        if return_dict:
            return {'image': crop_face_list, 'mask': mask_list}
        return crop_face_list, mask_list