| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Optional |
| | from torch import Tensor |
| | import smplx |
| |
|
| | from .base import Datastruct, dataclass, Transform |
| |
|
| | from .rots2rfeats import Rots2Rfeats |
| | from .rots2joints import Rots2Joints |
| | from .joints2jfeats import Joints2Jfeats |
| |
|
| |
|
class SMPLTransform(Transform):
    """Transform that builds lazily-converted ``SMPLDatastruct`` objects.

    It stores the three converter objects and hands them to every
    ``SMPLDatastruct`` it creates.
    """

    def __init__(self, rots2rfeats: Rots2Rfeats,
                 rots2joints: Rots2Joints,
                 joints2jfeats: Joints2Jfeats,
                 **kwargs):
        # Extra keyword arguments are accepted and deliberately ignored.
        self.rots2rfeats, self.rots2joints, self.joints2jfeats = (
            rots2rfeats, rots2joints, joints2jfeats)

    def Datastruct(self, **kwargs):
        """Create an ``SMPLDatastruct`` bound to this transform's converters.

        ``kwargs`` are forwarded as the datastruct's data fields
        (e.g. ``features=...``).
        """
        converters = dict(_rots2rfeats=self.rots2rfeats,
                          _rots2joints=self.rots2joints,
                          _joints2jfeats=self.joints2jfeats)
        return SMPLDatastruct(transforms=self, **converters, **kwargs)

    def __repr__(self):
        return "SMPLTransform()"
| |
|
| |
|
class RotIdentityTransform(Transform):
    """Transform whose datastruct keeps raw rotations and translations as-is."""

    def __init__(self, **kwargs):
        # Nothing to configure; keyword arguments are accepted and ignored.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields (``rots``, ``trans``) in a ``RotTransDatastruct``."""
        datastruct = RotTransDatastruct(**kwargs)
        return datastruct

    def __repr__(self):
        return "RotIdentityTransform()"
| |
|
| |
|
@dataclass
class RotTransDatastruct(Datastruct):
    """Container pairing a rotation tensor with a translation tensor."""
    rots: Tensor
    trans: Tensor

    # Default transform: a single instance shared by all datastructs; its
    # visible __init__ stores no state.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Names of the fields that hold data (as opposed to the transform).
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        """Length along the first axis of ``rots``."""
        rotations = self.rots
        return len(rotations)
| |
|
| |
|
@dataclass
class SMPLDatastruct(Datastruct):
    """Motion data whose representations are converted lazily and cached.

    Any of the ``*_`` fields may be supplied at construction time. A missing
    representation is computed on first access of the matching property,
    using the ``_rots2rfeats``, ``_rots2joints`` and ``_joints2jfeats``
    converters. The result is then stored back into its ``*_`` field.
    """
    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats

    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # Generic "features" double as rotation features unless those were
        # supplied explicitly.
        if self.rfeats_ is None and self.features is not None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations; derived from ``rfeats`` through the inverse feature map."""
        if self.rots_ is None:
            # Need rotation features to recover rotations.
            assert self.rfeats_ is not None
            rfeats = self.rfeats
            self._rots2rfeats.to(rfeats.device)
            self.rots_ = self._rots2rfeats.inverse(rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features; derived from ``rots`` when not supplied."""
        if self.rfeats_ is None:
            # Need rotations to compute rotation features.
            assert self.rots_ is not None
            rots = self.rots
            self._rots2rfeats.to(rots.device)
            self.rfeats_ = self._rots2rfeats(rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions computed from ``rots`` with ``_rots2joints``."""
        if self.joints_ is None:
            rots = self.rots
            self._rots2joints.to(rots.device)
            self.joints_ = self._rots2joints(rots)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features computed from ``joints`` with ``_joints2jfeats``."""
        if self.jfeats_ is None:
            joints = self.joints
            self._joints2jfeats.to(joints.device)
            self.jfeats_ = self._joints2jfeats(joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices computed from ``rots`` (``jointstype="vertices"``)."""
        if self.vertices_ is None:
            rots = self.rots
            self._rots2joints.to(rots.device)
            self.vertices_ = self._rots2joints(rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        """Length of the rotation-feature sequence (computed if needed)."""
        return len(self.rfeats)
| |
|
| |
|
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    """Load an SMPL-family body model with ``smplx.create``.

    The model file is read from
    ``data/smpl_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.

    Args:
        model_type: body model family understood by smplx (e.g. ``smpl``,
            ``smplh``, ``smplx``); refer to the smplx tutorial.
        gender: ``male``, ``female`` or ``neutral`` (case-insensitive). A
            non-``str`` value is converted with ``.astype(str)`` first
            (presumably a numpy string array/scalar loaded from a file).
        batch_size: a positive integer, forwarded to ``smplx.create``.
        device: device the created model is moved to via ``Module.to``
            (default ``'cpu'``).
        ext: model file extension; forced to ``npz`` for the neutral gender.

    Returns:
        The body model built by ``smplx.create``, placed on ``device``.
    """
    mtype = model_type.upper()

    # Normalize gender to an upper-case Python string.
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    gender = gender.upper()

    # The file name only depends on the upper-cased gender, so the neutral
    # check must be case-insensitive too. The neutral model is loaded from .npz.
    if gender == 'NEUTRAL':
        ext = 'npz'
    body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}'

    # Fix: the original passed the builtin ``type`` here instead of the
    # requested model type.
    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)

    # ``Module.to`` honours any device spec ('cpu', 'cuda', 'cuda:1',
    # torch.device); it is a no-op when the model is already there.
    return body_model.to(device)
| |
|
| |
|