|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pdb
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from .base_model import BaseModel
|
|
|
import torch
|
|
|
from torch.cuda import nvtx
|
|
|
from .predictor import numpy_to_torch_dtype_dict
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
def headpose_pred_to_degree(pred):
|
|
|
"""
|
|
|
pred: (bs, 66) or (bs, 1) or others
|
|
|
"""
|
|
|
if pred.ndim > 1 and pred.shape[1] == 66:
|
|
|
|
|
|
idx_array = np.arange(0, 66)
|
|
|
pred = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), 1, pred)
|
|
|
degree = np.sum(pred * idx_array, axis=1) * 3 - 97.5
|
|
|
|
|
|
return degree
|
|
|
|
|
|
return pred
|
|
|
|
|
|
|
|
|
class MotionExtractorModel(BaseModel):
    """
    MotionExtractorModel

    Preprocesses an image, runs it through the backend selected by
    ``self.predict_type`` and post-processes the raw outputs into the tuple
    ``(pitch, yaw, roll, t, exp, scale, kp)``.

    NOTE(review): ``predict_type``, ``predictor``, ``device`` and
    ``cudaStream`` are read here but not assigned in this class; they are
    presumably set up by ``BaseModel`` -- confirm.
    """

    def __init__(self, **kwargs):
        """
        Forward all keyword arguments to ``BaseModel`` and read the optional
        ``flag_refine_info`` switch (defaults to True).
        """
        super().__init__(**kwargs)
        self.flag_refine_info = kwargs.get("flag_refine_info", True)

    def input_process(self, *data):
        """
        Prepare the first positional argument for the predictor.

        ``data[0]`` is cast to float32, divided by 255, its axes are
        permuted with ``(2, 0, 1)`` and a leading axis of length 1 is added.
        Assumes ``data[0]`` is an (H, W, C) image array with values in
        0-255 -- TODO confirm against callers.
        """
        img = data[0].astype(np.float32)
        img /= 255.0
        img = np.transpose(img, (2, 0, 1))
        return img[None]

    def output_process(self, *data):
        """
        Reorder and optionally refine the seven raw predictor outputs.

        The "trt" backend yields ``kp`` first; every other backend yields it
        last. When ``self.flag_refine_info`` is truthy, pitch/yaw/roll are
        passed through ``headpose_pred_to_degree`` and given a trailing axis,
        and ``kp`` and ``exp`` are reshaped to ``(bs, -1, 3)``.

        Returns ``(pitch, yaw, roll, t, exp, scale, kp)``.
        """
        if self.predict_type == "trt":
            kp, pitch, yaw, roll, t, exp, scale = data
        else:
            pitch, yaw, roll, t, exp, scale, kp = data

        if self.flag_refine_info:
            bs = kp.shape[0]
            pitch = headpose_pred_to_degree(pitch)[:, None]
            yaw = headpose_pred_to_degree(yaw)[:, None]
            roll = headpose_pred_to_degree(roll)[:, None]
            kp = kp.reshape(bs, -1, 3)
            exp = exp.reshape(bs, -1, 3)

        return pitch, yaw, roll, t, exp, scale, kp

    def predict_trt(self, *data):
        """
        Run the predictor through its feed-dict interface.

        Positional arguments are matched by index to
        ``self.predictor.inputs``. Tensors are passed through as-is; anything
        else is converted with ``torch.from_numpy`` and moved to
        ``self.device`` with the dtype looked up in
        ``numpy_to_torch_dtype_dict``.

        Returns a list of NumPy arrays, one per entry of
        ``self.predictor.outputs`` and in that order.
        """
        nvtx.range_push("forward")
        try:
            feed_dict = {}
            for i, inp in enumerate(self.predictor.inputs):
                if isinstance(data[i], torch.Tensor):
                    feed_dict[inp['name']] = data[i]
                else:
                    feed_dict[inp['name']] = torch.from_numpy(data[i]).to(
                        device=self.device,
                        dtype=numpy_to_torch_dtype_dict[inp['dtype']],
                    )

            preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
            outs = [
                preds_dict[out["name"]].cpu().numpy()
                for out in self.predictor.outputs
            ]
        finally:
            # Always close the NVTX range, even if conversion or inference
            # raises, so the profiler's range stack stays balanced.
            nvtx.range_pop()
        return outs

    def predict(self, *data):
        """
        Run inference on ``data[0]`` and return the post-processed outputs.

        Dispatches to ``predict_trt`` when ``self.predict_type`` is "trt",
        otherwise calls ``self.predictor.predict`` directly; the raw outputs
        are then passed to ``output_process``.
        """
        img = data[0]
        if self.predict_type == "trt":
            preds = self.predict_trt(img)
        else:
            preds = self.predictor.predict(img)
        outputs = self.output_process(*preds)
        return outputs
|
|
|
|