| import torch |
| from utils.transforms import quat2mat, repr6d2mat, euler2mat |
|
|
|
|
class ForwardKinematics:
    """Forward kinematics on rotation matrices along a ``parents`` hierarchy.

    Joint 0 is treated as the root. For every other joint ``i`` the parent
    ``parents[i]`` must already have been processed, i.e. ``parents[i] < i``.
    All results are built out-of-place, so every method is differentiable
    with respect to its tensor inputs.
    """

    def __init__(self, parents, offsets=None):
        """
        @param parents: parent index of every joint (the entry for joint 0 is not read)
        @param offsets: default bone offsets, (bone_num, 3) or (batch_size, bone_num, 3), or None
        """
        self.parents = parents
        # Store offsets with an explicit batch axis so that offsets[:, i] selects joint i.
        if offsets is not None and len(offsets.shape) == 2:
            offsets = offsets.unsqueeze(0)
        self.offsets = offsets

    def forward(self, rots, offsets=None, global_pos=None):
        """
        Forward Kinematics: returns a per-bone transformation.

        Entry ``[:, i]`` is ``[R_i | t_i]``. ``R_i`` is joint i's accumulated
        (global) rotation and ``t_i = pos_i - R_i @ rest_pos_i``, so a point ``x``
        given in rest-pose coordinates maps to ``R_i @ x + t_i``.

        @param rots: local joint rotations (batch_size, bone_num, 3, 3)
        @param offsets: (batch_size, bone_num, 3), (bone_num, 3) or None (use the
            offsets given at construction); converted to the device/dtype of ``rots``
        @param global_pos: global root position, (batch_size, 3) or anything that
            broadcasts to it; defaults to the root offset (offsets[:, 0])
        @return: (batch_size, len(parents), 3, 4), same dtype as ``rots``
        """
        if offsets is None:
            offsets = self.offsets
        elif offsets.dim() == 2:
            # Accept un-batched offsets, consistent with __init__.
            offsets = offsets.unsqueeze(0)
        offsets = offsets.to(device=rots.device, dtype=rots.dtype)
        if global_pos is None:
            global_pos = offsets[:, 0]
        global_pos = torch.as_tensor(global_pos).to(device=rots.device, dtype=rots.dtype)

        # Per-joint results are collected in lists and stacked at the end; writing
        # into slices of a shared tensor would invalidate tensors autograd saved.
        glob_rots = []
        pos = []
        rest_pos = []
        res = []
        for i, p in enumerate(self.parents):
            if i == 0:
                rot_i = rots[:, 0]
                pos_i = global_pos
                rest_i = offsets[:, 0]
            else:
                rot_p = glob_rots[p]
                rot_i = torch.matmul(rot_p, rots[:, i])
                # Posed position: parent position plus the bone offset rotated by the parent.
                pos_i = torch.matmul(rot_p, offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[p]
                # Rest position: plain sum of offsets down the chain.
                rest_i = rest_pos[p] + offsets[:, i]
            glob_rots.append(rot_i)
            pos.append(pos_i)
            rest_pos.append(rest_i)

            trans_i = torch.matmul(rot_i, -rest_i.unsqueeze(-1)) + pos_i.unsqueeze(-1)
            res.append(torch.cat([rot_i, trans_i], dim=-1))

        return torch.stack(res, dim=1)

    def accumulate(self, local_rots):
        """
        Get global joint rotation from local rotations.

        @param local_rots: (batch_size, n_bone, 3, 3)
        @return: global rotations, (batch_size, len(parents), 3, 3)
        """
        glob = []
        for i, p in enumerate(self.parents):
            if i == 0:
                glob.append(local_rots[:, 0])
            else:
                glob.append(torch.matmul(glob[p], local_rots[:, i]))
        return torch.stack(glob, dim=1)

    def unaccumulate(self, global_rots):
        """
        Get local joint rotation from global rotations.

        @param global_rots: (batch_size, n_bone, 3, 3)
        @return: local rotations, (batch_size, len(parents), 3, 3)
        """
        local = []
        # inv[i] = local_i^T @ inv[p]. For orthonormal matrices this equals the
        # inverse (transpose) of joint i's global rotation.
        inv = []
        for i, p in enumerate(self.parents):
            if i == 0:
                local.append(global_rots[:, 0])
                inv.append(global_rots[:, 0].transpose(-2, -1))
                continue
            loc = torch.matmul(inv[p], global_rots[:, i])
            local.append(loc)
            inv.append(torch.matmul(loc.transpose(-2, -1), inv[p]))
        return torch.stack(local, dim=1)
|
|
|
|
class ForwardKinematicsJoint:
    """Forward kinematics that returns joint positions.

    Rotations are converted to 3x3 matrices with ``repr6d2mat`` / ``quat2mat`` /
    ``euler2mat`` and chained along ``parents``.
    """

    def __init__(self, parents, offset):
        """
        @param parents: parent index of every joint; only joint 0 may use -1
            (parents are presumably listed before their children — TODO confirm)
        @param offset: default bone offsets, used when forward() gets no offset
        """
        self.parents = parents
        self.offset = offset

    def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
                world=True):
        """
        Compute joint positions from per-joint rotations.

        The representation is picked from the size of the last axis of
        ``rotation``: 6 -> 6D rotation, 4 -> quaternion (normalized here),
        3 -> Euler angles. The converted matrices are indexed as
        (..., Joint_num, 3, 3), presumably (batch_size, Time, Joint_num, 3, 3).

        @param rotation: (..., Joint_num, 6/4/3) local joint rotations
        @param position: root position, broadcastable to (..., 3),
            e.g. (batch_size, Time, 3)
        @param offset: (Joint_num, 3) or (batch_size, Joint_num, 3); defaults to self.offset
        @param world: if True, return world-space positions; otherwise each non-root
            entry is the joint's offset rotated by its parent's global rotation
        @return: (..., Joint_num, 3), e.g. batch_size * Time * Joint_num * 3
        @raise Exception: if the last axis of rotation is not 6, 4 or 3
        """
        if rotation.shape[-1] == 6:
            transform = repr6d2mat(rotation)
        elif rotation.shape[-1] == 4:
            # Normalize so that non-unit quaternions still describe rotations.
            norm = torch.norm(rotation, dim=-1, keepdim=True)
            rotation = rotation / norm
            transform = quat2mat(rotation)
        elif rotation.shape[-1] == 3:
            transform = euler2mat(rotation)
        else:
            raise Exception('Only accept 6D, quaternion or Euler rotation input')
        # Match the rotation dtype instead of the global default dtype.
        result = torch.empty(transform.shape[:-2] + (3,), dtype=transform.dtype,
                             device=position.device)

        if offset is None:
            offset = self.offset
        # (-1, 1, Joint_num, 3, 1): column vector per joint, with a singleton axis
        # that broadcasts against the second leading axis of transform.
        offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))

        result[..., 0, :] = position
        for i, pi in enumerate(self.parents):
            if pi == -1:
                assert i == 0
                continue

            # Clone once: both matmuls below may save this slice for backward, and
            # transform is written in place right after.
            parent_rot = transform[..., pi, :, :].clone()
            # squeeze(-1) drops only the column axis; a bare squeeze() would also drop
            # singleton batch/time axes (e.g. one frame with batch > 1) and fail.
            result[..., i, :] = torch.matmul(parent_rot, offset[..., i, :, :]).squeeze(-1)
            transform[..., i, :, :] = torch.matmul(parent_rot, transform[..., i, :, :].clone())
            if world:
                result[..., i, :] += result[..., pi, :]
        return result
|
|
|
|
class InverseKinematicsJoint:
    """Fit root positions and joint rotations to target joint positions with Adam.

    Each call to ``step`` runs ForwardKinematicsJoint on the current estimates
    and minimizes the MSE to ``constrains``.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains,
                 lr=1e-3):
        """
        @param rotations: initial rotations (any format ForwardKinematicsJoint.forward accepts);
            a detached copy is optimized
        @param positions: initial root positions; a detached copy is optimized
        @param offset: bone offsets passed to ForwardKinematicsJoint
        @param parents: parent index of every joint
        @param constrains: target joint positions, compared with the full FK output
        @param lr: Adam learning rate
        """
        # Optimize detached copies so the caller's tensors are never modified.
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position = positions.detach().clone()
        self.position.requires_grad_(True)

        self.parents = parents
        self.offset = offset
        # Targets are constants of the fit (InverseKinematicsJoint2 detaches them too).
        # A target still attached to an autograd graph would otherwise be backpropagated
        # through on every step; for a non-leaf target this fails once its graph is freed.
        # detach() keeps sharing storage with the caller's tensor.
        self.constrains = constrains.detach()

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=lr, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()

        self.fk = ForwardKinematicsJoint(parents, offset)

        # FK output of the most recent step (None before the first step).
        self.glb = None

    def step(self):
        """Run one optimization step and return the loss as a Python float."""
        self.optimizer.zero_grad()
        glb = self.fk.forward(self.rotations, self.position)
        loss = self.criteria(glb, self.constrains)
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()
|
|
|
|
class InverseKinematicsJoint2:
    """Fit root positions and joint rotations to a subset of target joints with Adam.

    The loss is the MSE on the constrained joints ``cid`` plus weighted MSE terms
    that keep rotations and positions close to their initial values. With
    ``use_velo`` the root trajectory is optimized as per-step deltas along axis 0.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
                 lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False, lr=1e-3):
        """
        @param rotations: initial rotations (any format ForwardKinematicsJoint.forward accepts)
        @param positions: initial root positions; axis 0 is differenced when use_velo is set
            (presumably the time axis — TODO confirm)
        @param offset: bone offsets passed to ForwardKinematicsJoint
        @param parents: parent index of every joint
        @param constrains: target positions for the joints selected by ``cid``
        @param cid: joint indices, applied to axis 1 of the FK output
        @param lambda_rec_rot: weight of the rotation reconstruction term
        @param lambda_rec_pos: weight of the position reconstruction term
        @param use_velo: optimize first-order differences of positions instead of positions
        @param lr: Adam learning rate
        """
        self.use_velo = use_velo
        self.rotations_ori = rotations.detach().clone()
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        # Absolute positions, kept for the reconstruction loss in both modes.
        self.position_ori = positions.detach().clone()
        self.position = positions.detach().clone()
        if self.use_velo:
            # Keep the first entry absolute and store deltas after it, so cumsum recovers positions.
            self.position[1:] = self.position[1:] - self.position[:-1]
        self.position.requires_grad_(True)

        self.parents = parents
        self.offset = offset
        self.constrains = constrains.detach().clone()
        self.cid = cid

        self.lambda_rec_rot = lambda_rec_rot
        self.lambda_rec_pos = lambda_rec_pos

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=lr, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()

        self.fk = ForwardKinematicsJoint(parents, offset)

        # FK output of the most recent step (None before the first step).
        self.glb = None

    def _absolute_position(self):
        # Absolute positions from the optimized parameter (undoes the delta encoding).
        if self.use_velo:
            return torch.cumsum(self.position, dim=0)
        return self.position

    def step(self):
        """Run one optimization step and return the total loss as a Python float."""
        self.optimizer.zero_grad()
        position = self._absolute_position()
        glb = self.fk.forward(self.rotations, position)
        self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
        self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
        # Compare absolute positions with absolute positions. In velocity mode
        # self.position holds deltas, while position_ori holds absolute positions.
        self.rec_loss_pos = self.criteria(position, self.position_ori)
        loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()

    def get_position(self):
        """Return the current absolute positions, detached from the autograd graph."""
        return self._absolute_position().detach()
|
|