| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
| from torch import Tensor |
|
|
| from .base import Datastruct, dataclass, Transform |
| from ..tools import collate_tensor_with_padding |
|
|
| from .joints2jfeats import Joints2Jfeats |
|
|
|
|
| class XYZTransform(Transform): |
| def __init__(self, joints2jfeats: Joints2Jfeats, **kwargs): |
| self.joints2jfeats = joints2jfeats |
|
|
| def Datastruct(self, **kwargs): |
| return XYZDatastruct(_joints2jfeats=self.joints2jfeats, |
| transforms=self, |
| **kwargs) |
|
|
| def __repr__(self): |
| return "XYZTransform()" |
|
|
|
|
| @dataclass |
| class XYZDatastruct(Datastruct): |
| transforms: XYZTransform |
| _joints2jfeats: Joints2Jfeats |
|
|
| features: Optional[Tensor] = None |
| joints_: Optional[Tensor] = None |
| jfeats_: Optional[Tensor] = None |
|
|
| def __post_init__(self): |
| self.datakeys = ["features", "joints_", "jfeats_"] |
| |
| if self.features is not None and self.jfeats_ is None: |
| self.jfeats_ = self.features |
|
|
| @property |
| def joints(self): |
| |
| if self.joints_ is not None: |
| return self.joints_ |
|
|
| |
| assert self.jfeats_ is not None |
|
|
| self._joints2jfeats.to(self.jfeats.device) |
| self.joints_ = self._joints2jfeats.inverse(self.jfeats) |
| return self.joints_ |
|
|
| @property |
| def jfeats(self): |
| |
| if self.jfeats_ is not None: |
| return self.jfeats_ |
|
|
| |
| assert self.joints_ is not None |
|
|
| self._joints2jfeats.to(self.joints.device) |
| self.jfeats_ = self._joints2jfeats(self.joints) |
| return self.jfeats_ |
|
|
| def __len__(self): |
| return len(self.jfeats) |
|
|