| |
| |
| |
| |
| import numpy as np |
| from model_lib import ModelBase |
|
|
# Registry of selectable HifiFace models: model name -> config dict that
# HifiFace.__init__ passes to ModelBase.
MODEL_ZOO = dict(
    er8_bs1=dict(model_path='pretrain_models/9O_865k.onnx'),
)
|
|
|
|
class HifiFace(ModelBase):
    """Face-swap model wrapper built on ``ModelBase``.

    The model configuration is looked up by name in ``MODEL_ZOO``.
    Inference is delegated to ``self.model.forward``, which ``ModelBase``
    is assumed to set up (not visible in this file).
    """

    def __init__(self, model_name='er8_bs1', provider='gpu'):
        """Select a model from ``MODEL_ZOO`` and initialise the backend.

        Args:
            model_name: Key into ``MODEL_ZOO`` naming the model config.
            provider: Passed unchanged to ``ModelBase.__init__``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODEL_ZOO``.
        """
        try:
            model_cfg = MODEL_ZOO[model_name]
        except KeyError:
            # Same exception type as before, but tell the caller what is valid.
            raise KeyError(
                'unknown HifiFace model {!r}; available: {}'.format(
                    model_name, sorted(MODEL_ZOO))
            ) from None
        super().__init__(model_cfg, provider)

    def forward(self, src_face_image, dst_face_latent):
        """Run a single face-swap inference.

        Args:
            src_face_image: Image array indexed as (H, W, C). The
                ``/ 255.0`` scaling suggests uint8 pixels in [0, 255]
                (TODO confirm against callers). It is mapped to [-1, 1]
                and fed to the model as a (1, C, H, W) float32 tensor.
            dst_face_latent: Array-like latent input, passed to the model as
                float32. Its expected shape is not visible here.

        Returns:
            tuple: ``(mask, swap_face)``, normalised to this order for every
            backend.
        """
        # (x / 255) * 2 - 1 rescales [0, 255] to [-1, 1]; [None] adds the
        # batch axis. Arithmetic on a transposed view keeps the transposed
        # memory layout (NumPy order='K'), so the result is not C-contiguous.
        # Backends that read the raw buffer may require C order, so force a
        # contiguous float32 copy. The values are identical to before.
        img_tensor = np.ascontiguousarray(
            ((src_face_image.transpose(2, 0, 1) / 255.0) * 2 - 1)[None],
            dtype=np.float32,
        )
        latent = np.ascontiguousarray(dst_face_latent, dtype=np.float32)
        output = self.model.forward([img_tensor, latent])

        # The backend's output order depends on the model type: 'trt' yields
        # (mask, swap_face), the others yield (swap_face, mask).
        if self.model_type == 'trt':
            mask, swap_face = output
        else:
            swap_face, mask = output

        return mask, swap_face
|
|