|
|
import argparse |
|
|
import os |
|
|
import random |
|
|
from glob import glob |
|
|
from pathlib import Path |
|
|
|
|
|
import matplotlib |
|
|
import platform |
|
|
if platform.system() == 'Windows': |
|
|
matplotlib.use('TkAgg') |
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, |
|
|
os.path.join( |
|
|
os.path.dirname(__file__),'./' |
|
|
) |
|
|
) |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
from insightface.app import FaceAnalysis |
|
|
from insightface.app.common import Face |
|
|
from insightface.utils import face_align |
|
|
from loguru import logger |
|
|
from pytorch3d.io import save_ply |
|
|
from skimage.io import imread |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .micalib.config import get_cfg_defaults |
|
|
from .datasets.creation.util import get_arcface_input, get_center |
|
|
from .micalib import util |
|
|
|
|
|
from typing import Union,NewType,Optional |
|
|
|
|
|
class api_MICA(object): |
|
|
def __init__(self): |
|
|
cfg = get_cfg_defaults() |
|
|
device = 'cuda:0' |
|
|
cfg.model.testing = True |
|
|
self.mica = util.find_model_using_name(model_dir='micalib', model_name=cfg.model.name)(cfg, device) |
|
|
self.load_checkpoint(self.mica,model_path=cfg.pretrained_model_path) |
|
|
self.mica.eval() |
|
|
|
|
|
self.app = FaceAnalysis(name='antelopev2', |
|
|
providers=['CPUExecutionProvider'], |
|
|
|
|
|
) |
|
|
self.app.prepare(ctx_id=0, det_size=(224, 224)) |
|
|
logger.info('MICA api init done.') |
|
|
|
|
|
def read_arcface(self,input_img_path): |
|
|
def process(img:Union[np.ndarray,str], |
|
|
image_size=224 |
|
|
): |
|
|
if isinstance(img,str): |
|
|
img = cv2.imread(img) |
|
|
bboxes, kpss = self.app.det_model.detect( |
|
|
img, |
|
|
max_num=0, |
|
|
metric='default' |
|
|
) |
|
|
|
|
|
if bboxes.shape[0] == 0: |
|
|
|
|
|
logger.error(f'No face detected in {img_path}') |
|
|
import sys |
|
|
sys.exit(0) |
|
|
|
|
|
i = get_center(bboxes, img) |
|
|
bbox = bboxes[i, 0:4] |
|
|
det_score = bboxes[i, 4] |
|
|
kps = None |
|
|
if kpss is not None: |
|
|
kps = kpss[i] |
|
|
face = Face( |
|
|
bbox=bbox, |
|
|
kps=kps, |
|
|
det_score=det_score |
|
|
) |
|
|
blob, aimg = get_arcface_input(face, img) |
|
|
crop_im=face_align.norm_crop(img, |
|
|
landmark=face.kps, |
|
|
image_size=image_size) |
|
|
np.save('tmp.npy', blob) |
|
|
cv2.imwrite('tmp.jpg', crop_im) |
|
|
return blob,crop_im |
|
|
|
|
|
blob,crop_im=process(input_img_path) |
|
|
|
|
|
if 0: |
|
|
img_path='tmp.jpg' |
|
|
image = imread(img_path)[:, :, :3] |
|
|
else: |
|
|
image=crop_im |
|
|
|
|
|
image = image / 255. |
|
|
image = cv2.resize(image, (224, 224)).transpose(2, 0, 1) |
|
|
image = torch.tensor(image).cuda()[None] |
|
|
|
|
|
if 0: |
|
|
path='tmp.npy' |
|
|
arcface = np.load(path) |
|
|
else: |
|
|
arcface=blob |
|
|
|
|
|
arcface = torch.tensor(arcface).cuda()[None] |
|
|
|
|
|
logger.info(f'read and process arcface done.') |
|
|
return image, arcface |
|
|
|
|
|
def load_checkpoint(self, |
|
|
mica, |
|
|
model_path:Union[str,Path], |
|
|
): |
|
|
checkpoint = torch.load(model_path) |
|
|
if 'arcface' in checkpoint: |
|
|
mica.arcface.load_state_dict(checkpoint['arcface']) |
|
|
if 'flameModel' in checkpoint: |
|
|
mica.flameModel.load_state_dict(checkpoint['flameModel']) |
|
|
|
|
|
def predict(self, |
|
|
input_img_path, |
|
|
output_ply_path=None, |
|
|
output_render_path=None, |
|
|
output_param_npy_path=None, |
|
|
): |
|
|
faces = self.mica.render.faces[0].cpu() |
|
|
with torch.no_grad(): |
|
|
images, arcface = self.read_arcface(input_img_path) |
|
|
|
|
|
codedict = self.mica.encode(images, arcface) |
|
|
opdict = self.mica.decode(codedict) |
|
|
meshes = opdict['pred_canonical_shape_vertices'] |
|
|
code = opdict['pred_shape_code'] |
|
|
mesh = meshes[0] |
|
|
|
|
|
rendering = self.mica.render.render_mesh(mesh[None]) |
|
|
image = (rendering[0] |
|
|
.cpu() |
|
|
.numpy() |
|
|
.transpose(1, 2, 0) |
|
|
.copy() |
|
|
* 255)[:, :, [2, 1, 0]] |
|
|
|
|
|
image = np.minimum( |
|
|
np.maximum(image, 0), 255 |
|
|
).astype(np.uint8) |
|
|
|
|
|
if output_render_path: |
|
|
Path(output_render_path).parent.mkdir(parents=True, exist_ok=True) |
|
|
cv2.imwrite(output_render_path, image) |
|
|
logger.info(f'MICA img: {output_render_path}') |
|
|
|
|
|
if output_ply_path: |
|
|
Path(output_ply_path).parent.mkdir(parents=True, exist_ok=True) |
|
|
save_ply(output_ply_path, |
|
|
verts=mesh.cpu() * 1000.0, |
|
|
faces=faces) |
|
|
logger.info(f'MICA ply: {output_ply_path}') |
|
|
|
|
|
if output_param_npy_path: |
|
|
Path(output_param_npy_path).parent.mkdir(parents=True, exist_ok=True) |
|
|
np.save(output_param_npy_path, |
|
|
code[0].cpu().numpy()) |
|
|
logger.info(f'MICA npy: {output_param_npy_path}') |
|
|
|
|
|
|