# SAT-HMR / models/human_models/smpl_models.py
# Uploaded by ChiSu001 — commit ff07ed4 ("Upload model files", verified)
import torch
from torch import nn
import smplx
import numpy as np
import pickle
import os.path as osp
from configs.paths import smpl_model_path
class SMPL_Layer(nn.Module):
    """Extension of the SMPL Layer with gendered inputs.

    Holds one ``smplx`` SMPL model per gender (always neutral; male and female
    as well when ``with_genders`` is True) and lets ``forward`` run a batch in
    which each sample can carry its own gender label.
    """

    def __init__(self, model_path, with_genders=True, **kwargs):
        """Create the per-gender SMPL layers and load auxiliary SMPL arrays.

        Args:
            model_path: Directory handed to ``smplx.create``. The files
                ``smpl/body_verts_smpl.npy`` and
                ``smpl/J_regressor_h36m_correct.npy`` are also read from it.
            with_genders: When True, male and female layers are built in
                addition to the neutral one so ``forward`` accepts genders.
            **kwargs: Extra ``smplx.create`` options; they override the
                ``create_*`` defaults set below.
        """
        super().__init__()
        # The layers are built without their own pose/shape/translation
        # parameters; forward_single_gender passes pose and betas explicitly.
        smpl_kwargs = {'create_global_orient': False, 'create_body_pose': False,
                       'create_betas': False, 'create_transl': False}
        smpl_kwargs.update(kwargs)
        self.with_genders = with_genders
        # The neutral layer is needed in both configurations; it is created
        # first so submodule registration order is neutral, male, female.
        self.layer_n = smplx.create(model_path, 'smpl', gender='neutral', **smpl_kwargs)
        self.layers = {'neutral': self.layer_n}
        if self.with_genders:
            self.layer_m = smplx.create(model_path, 'smpl', gender='male', **smpl_kwargs)
            self.layer_f = smplx.create(model_path, 'smpl', gender='female', **smpl_kwargs)
            self.layers['male'] = self.layer_m
            self.layers['female'] = self.layer_f
        # Hard-coded SMPL vertex count.
        self.vertex_num = 6890
        # Mesh faces taken from the neutral layer.
        self.faces = self.layer_n.faces
        # Loaded as plain NumPy arrays (not registered buffers); consumers are
        # outside this class — presumably body-vertex selection and an
        # H36M joint regressor, as the file names suggest. TODO confirm.
        self.body_vertex_idx = np.load(osp.join(model_path, 'smpl', 'body_verts_smpl.npy'))
        self.smpl2h36m_regressor = np.load(osp.join(model_path, 'smpl', 'J_regressor_h36m_correct.npy'))

    def forward_single_gender(self, poses, betas, gender='neutral'):
        """Run a single gendered SMPL layer on the whole batch.

        Args:
            poses: Per-joint pose tensor, flat ``(B, 72)`` or ``(B, 24, 3)``.
                Joint 0 is passed as ``global_orient`` and joints 1-23 as
                ``body_pose`` (presumably axis-angle, smplx's default).
            betas: Shape coefficients for the same ``B`` samples.
            gender: Key of ``self.layers`` to use.

        Returns:
            ``(vertices, joints)`` from the selected smplx layer.

        Raises:
            ValueError: If ``poses`` does not describe 24 joints.
            KeyError: If ``gender`` has no layer.
        """
        bs = poses.shape[0]
        if poses.ndim == 2:
            poses = poses.reshape(bs, -1, 3)
        if poses.shape[1] != 24:
            raise ValueError(f'expected 24 pose joints, got {poses.shape[1]}')
        pose_params = {'global_orient': poses[:, :1, :],
                       'body_pose': poses[:, 1:, :]}
        smpl_output = self.layers[gender](betas=betas, **pose_params)
        return smpl_output.vertices, smpl_output.joints

    def forward(self, poses, betas, genders=None):
        """Run SMPL on a batch whose samples may have different genders.

        Args:
            poses: Pose tensor accepted by ``forward_single_gender``.
            betas: Shape coefficients; batch size must equal that of ``poses``.
            genders: ``None`` to use the neutral layer for every sample, or a
                length-``B`` sequence of ``'male'`` / ``'female'`` labels.

        Returns:
            ``(vertices, joints)`` where sample ``i`` comes from the layer
            selected by ``genders[i]``.

        Raises:
            ValueError: On batch-size mismatch, unknown gender labels, or
                gendered input to a layer built with ``with_genders=False``.
        """
        bs = poses.shape[0]
        if betas.shape[0] != bs:
            raise ValueError(f'poses batch size {bs} != betas batch size {betas.shape[0]}')
        if genders is None:
            return self.forward_single_gender(poses, betas)
        if len(genders) != bs:
            raise ValueError(f'expected {bs} gender labels, got {len(genders)}')
        unknown = set(genders) - {'male', 'female'}
        if unknown:
            raise ValueError(f'unsupported gender labels: {sorted(map(str, unknown))}')
        if not self.with_genders:
            raise ValueError('genders given but the layer was built with with_genders=False')

        male_idx = [i for i, gender in enumerate(genders) if gender == 'male']
        female_idx = [i for i, gender in enumerate(genders) if gender != 'male']
        # Homogeneous batches (including the empty batch) need a single pass.
        if not female_idx:
            return self.forward_single_gender(poses, betas, gender='male')
        if not male_idx:
            return self.forward_single_gender(poses, betas, gender='female')

        # Mixed batch: run each layer only on its own samples, then scatter
        # the results back into batch order.
        m_verts, m_joints = self.forward_single_gender(
            poses[male_idx], betas[male_idx], gender='male')
        f_verts, f_joints = self.forward_single_gender(
            poses[female_idx], betas[female_idx], gender='female')
        vertices = m_verts.new_empty((bs, *m_verts.shape[1:]))
        joints = m_joints.new_empty((bs, *m_joints.shape[1:]))
        vertices[male_idx] = m_verts
        vertices[female_idx] = f_verts
        joints[male_idx] = m_joints
        joints[female_idx] = f_joints
        return vertices, joints
# Shared module-level instance (built at import time) with neutral, male and
# female SMPL layers loaded from the configured model directory.
smpl_gendered = SMPL_Layer(model_path=smpl_model_path, with_genders=True)