# Source: show / SHOW /modules /MICA /api_MICA.py
# camenduru's picture
# thanks to show ❤
# 3bbb319
import argparse
import os
import random
from glob import glob
from pathlib import Path
import matplotlib
import platform
if platform.system() == 'Windows':
matplotlib.use('TkAgg')
import sys
sys.path.insert(0,
os.path.join(
os.path.dirname(__file__),'./'
)
)
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from insightface.app import FaceAnalysis
from insightface.app.common import Face
from insightface.utils import face_align
from loguru import logger
from pytorch3d.io import save_ply
from skimage.io import imread
from tqdm import tqdm
from .micalib.config import get_cfg_defaults
from .datasets.creation.util import get_arcface_input, get_center
from .micalib import util
from typing import Union,NewType,Optional
class api_MICA(object):
    """Inference wrapper around a pretrained MICA model.

    Builds the model named by the default config, loads its pretrained
    weights, and prepares an insightface ``FaceAnalysis`` app whose detector
    is used to locate the face in the input image. :meth:`predict` then
    optionally writes a rendered image, a ``.ply`` mesh and a shape-code
    ``.npy`` file for a single input image.
    """

    def __init__(self, device='cuda:0'):
        """Create the model, load its checkpoint and prepare the detector.

        Args:
            device: Torch device the model is built on and the input tensors
                are moved to. Defaults to ``'cuda:0'``, the value that was
                previously hard-coded.
        """
        cfg = get_cfg_defaults()
        cfg.model.testing = True
        self.device = device

        model_cls = util.find_model_using_name(
            model_dir='micalib',
            model_name=cfg.model.name,
        )
        self.mica = model_cls(cfg, device)
        self.load_checkpoint(self.mica, model_path=cfg.pretrained_model_path)
        self.mica.eval()

        self.app = FaceAnalysis(
            name='antelopev2',
            # Alternative kept from the original: ['CUDAExecutionProvider']
            providers=['CPUExecutionProvider'],
        )
        self.app.prepare(ctx_id=0, det_size=(224, 224))
        logger.info('MICA api init done.')

    def _detect_and_crop(self, img, image_size=224):
        """Detect a face in ``img`` and build the cropped inputs for MICA.

        Args:
            img: Image path (read with ``cv2.imread``) or an already loaded
                image array, which is handed to the detector unchanged.
            image_size: Side length passed to ``face_align.norm_crop``.

        Returns:
            ``(blob, crop_im)``: the first value returned by
            ``get_arcface_input`` and the crop from ``face_align.norm_crop``.

        Raises:
            ValueError: If ``img`` is a path that ``cv2.imread`` cannot read.
            SystemExit: If the detector finds no face.
        """
        source = img if isinstance(img, str) else '<in-memory image>'
        if isinstance(img, str):
            img = cv2.imread(img)
            if img is None:
                # cv2.imread reports failure by returning None, not raising.
                raise ValueError(f'Failed to read image: {source}')

        bboxes, kpss = self.app.det_model.detect(
            img,
            max_num=0,
            metric='default',
        )
        if bboxes.shape[0] == 0:
            logger.error(f'No face detected in {source}')
            # NOTE(review): exit status 0 is kept from the original code;
            # confirm whether callers would rather get a non-zero status
            # or an exception here.
            sys.exit(0)

        # Index of the detection to use, as chosen by get_center.
        i = get_center(bboxes, img)
        kps = kpss[i] if kpss is not None else None
        face = Face(
            bbox=bboxes[i, 0:4],
            kps=kps,
            det_score=bboxes[i, 4],
        )
        blob, _ = get_arcface_input(face, img)
        crop_im = face_align.norm_crop(
            img,
            landmark=face.kps,
            image_size=image_size,
        )

        # Dumps into the current working directory, kept from the original.
        # Nothing in this class reads them back.
        np.save('tmp.npy', blob)
        cv2.imwrite('tmp.jpg', crop_im)
        return blob, crop_im

    def read_arcface(self, input_img_path):
        """Turn one input image into the two tensors fed to ``mica.encode``.

        Args:
            input_img_path: Image path or already loaded image array.

        Returns:
            ``(image, arcface)``: the face crop scaled by 1/255, resized to
            224x224 and transposed with ``(2, 0, 1)``, and the blob from
            ``get_arcface_input``. Both are moved to ``self.device`` and get
            a leading axis of size 1.
        """
        blob, crop_im = self._detect_and_crop(input_img_path)

        image = crop_im / 255.
        image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
        image = torch.tensor(image).to(self.device)[None]

        arcface = torch.tensor(blob).to(self.device)[None]
        logger.info('read and process arcface done.')
        return image, arcface

    def load_checkpoint(self,
                        mica,
                        model_path: Union[str, Path],
                        ):
        """Load pretrained sub-module weights from ``model_path`` into ``mica``.

        Only the ``'arcface'`` and ``'flameModel'`` entries of the checkpoint
        are used; an entry that is absent is skipped and the corresponding
        sub-module keeps its current weights.

        Args:
            mica: Object exposing ``arcface`` and ``flameModel`` modules.
            model_path: Checkpoint file readable by ``torch.load``.
        """
        # Deserialize onto the CPU so the file loads regardless of the device
        # it was saved from; load_state_dict copies the values into the
        # module's existing parameters.
        checkpoint = torch.load(model_path, map_location='cpu')
        if 'arcface' in checkpoint:
            mica.arcface.load_state_dict(checkpoint['arcface'])
        if 'flameModel' in checkpoint:
            mica.flameModel.load_state_dict(checkpoint['flameModel'])

    def predict(self,
                input_img_path,
                output_ply_path=None,
                output_render_path=None,
                output_param_npy_path=None,
                ):
        """Run MICA on one image and write the requested outputs.

        Each output is written only when its path is truthy; parent
        directories are created as needed. The mesh is rendered only when
        ``output_render_path`` is given.

        Args:
            input_img_path: Image path or already loaded image array.
            output_ply_path: Destination of the mesh (vertices scaled by
                1000) written with ``save_ply``.
            output_render_path: Destination of the rendered image written
                with ``cv2.imwrite``.
            output_param_npy_path: Destination of the first predicted shape
                code written with ``np.save``.

        Returns:
            None.
        """
        with torch.no_grad():
            images, arcface = self.read_arcface(input_img_path)
            codedict = self.mica.encode(images, arcface)
            opdict = self.mica.decode(codedict)
            mesh = opdict['pred_canonical_shape_vertices'][0]
            code = opdict['pred_shape_code']

            if output_render_path:
                rendering = self.mica.render.render_mesh(mesh[None])
                image = rendering[0].cpu().numpy().transpose(1, 2, 0).copy() * 255
                # Reverse the channel order before handing it to cv2.imwrite.
                image = image[:, :, [2, 1, 0]]
                image = np.clip(image, 0, 255).astype(np.uint8)
                Path(output_render_path).parent.mkdir(parents=True, exist_ok=True)
                # str() so a pathlib.Path argument is accepted as well.
                cv2.imwrite(str(output_render_path), image)
                logger.info(f'MICA img: {output_render_path}')

            if output_ply_path:
                faces = self.mica.render.faces[0].cpu()
                Path(output_ply_path).parent.mkdir(parents=True, exist_ok=True)
                # NOTE(review): the 1000x factor presumably converts units
                # (e.g. metres to millimetres) — confirm.
                save_ply(output_ply_path,
                         verts=mesh.cpu() * 1000.0,
                         faces=faces)
                logger.info(f'MICA ply: {output_ply_path}')

            if output_param_npy_path:
                Path(output_param_npy_path).parent.mkdir(parents=True, exist_ok=True)
                np.save(output_param_npy_path,
                        code[0].cpu().numpy())
                logger.info(f'MICA npy: {output_param_npy_path}')