import xml.etree.ElementTree as ET
import numpy as np
import torch

from . import torch_utils


class Joint:
    """A joint attached to one body, with 0, 1 or 3 rotational dofs.

    Quaternions handled here store the scalar part in the last component
    (the identity is written as ``[..., -1] = 1``).
    """

    def __init__(self, name, dof_dim, axis):
        """Store the joint description.

        Args:
            name: Joint name (the parser passes the owning body's name).
            dof_dim: Number of dofs; the methods below handle 0, 1 and 3.
            axis: Rotation axis tensor of shape [3] for a 1-dof joint,
                ``None`` otherwise.
        """
        self._name = name
        self._dof_dim = dof_dim
        self._axis = axis
        # Start index of this joint's dofs in the whole dof vector;
        # -1 for the root or for a joint without dofs.
        self._dof_idx = -1

    def set_dof_idx(self, dof_idx):
        """Set the start index of this joint's dofs in the whole dof vector.

        Raises:
            ValueError: If the joint has no dof.
        """
        if self._dof_dim == 0:
            raise ValueError('Joint {} has no dof'.format(self._name))
        self._dof_idx = dof_idx

    def dof_to_rot(self, dof):
        """Convert this joint's dof values to a quaternion.

        Args:
            dof: Tensor of shape [..., dof_dim].

        Returns:
            Tensor of shape [..., 4] with the dtype and device of ``dof``.
            A 0-dof joint yields the identity, a 1-dof joint a rotation about
            the joint axis, and a 3-dof joint treats ``dof`` as an exponential
            map.
        """
        rot_shape = list(dof.shape[:-1]) + [4]
        ret_rot = torch.zeros(rot_shape, dtype=dof.dtype, device=dof.device)

        if self._dof_dim == 0:
            ret_rot[..., -1] = 1.0
        elif self._dof_dim == 1:
            # Expand the single [3] axis to one axis per batch element.
            axis = torch.broadcast_to(self._axis, ret_rot[..., 0:3].shape)
            ret_rot[:] = torch_utils.axis_angle_to_quat(axis, dof.squeeze(-1))
        elif self._dof_dim == 3:
            ret_rot[:] = torch_utils.exp_map_to_quat(dof)

        return ret_rot

    def rot_to_dof(self, rot):
        """Convert a quaternion to this joint's dof values.

        Args:
            rot: Tensor of shape [..., 4].

        Returns:
            Tensor of shape [..., dof_dim] with the dtype and device of
            ``rot``. For a 1-dof joint the angle is negated when the recovered
            rotation axis points against the joint axis.
        """
        dof_shape = list(rot.shape[:-1]) + [self._dof_dim]
        ret_dof = torch.zeros(dof_shape, dtype=rot.dtype, device=rot.device)

        if self._dof_dim == 1:
            rot_axis, angle = torch_utils.quat_to_axis_angle(rot)
            dot_axis = torch.sum(rot_axis * self._axis, dim=-1)
            # Out-of-place sign flip: avoids mutating the helper's output.
            angle = torch.where(dot_axis < 0, -angle, angle)
            ret_dof[:] = angle.unsqueeze(-1)
        elif self._dof_dim == 3:
            ret_dof[:] = torch_utils.quat_to_exp_map(rot)

        return ret_dof

    @property
    def dof_dim(self):
        """Number of dofs of this joint."""
        return self._dof_dim

    @property
    def name(self):
        """Name of this joint."""
        return self._name

    @property
    def dof_idx(self):
        """Start index in the whole dof vector, or -1 if never assigned."""
        return self._dof_idx


class KinematicsModel:
    """Kinematic tree parsed from an XML description.

    Bodies are stored in depth-first order starting from the first ``<body>``
    under ``<worldbody>``; index 0 is the root. Each body owns exactly one
    ``Joint`` object (with 0, 1 or 3 dofs).
    """

    def __init__(self, file_path, device):
        """Parse ``file_path`` and build all tensors on ``device``.

        Raises:
            NotImplementedError: If the file does not end with ``.xml``.
        """
        self._device = device
        self._file_path = file_path

        self._build_kinematics_model()
        self._set_dof_indices()

    def _build_kinematics_model(self):
        """Parse the model file and convert the collected lists to tensors."""
        self._body_names = []
        self._parent_indices = []
        self._local_translation = []
        self._local_rotation = []
        self._joints = []
        self._dof_size = []
        self._dof_upper_limits = []
        self._dof_lower_limits = []

        if self._file_path.endswith('.xml'):
            self._parse_xml()
        else:
            raise NotImplementedError('File type not supported')

        self._parent_indices = torch.tensor(self._parent_indices, dtype=torch.long, device=self._device)
        self._local_translation = torch.tensor(np.array(self._local_translation), dtype=torch.float, device=self._device)
        self._local_rotation = torch.tensor(np.array(self._local_rotation), dtype=torch.float, device=self._device)
        self._num_dof = sum(self._dof_size)
        self._dof_lower_limits = torch.tensor(self._dof_lower_limits, dtype=torch.float, device=self._device)
        self._dof_upper_limits = torch.tensor(self._dof_upper_limits, dtype=torch.float, device=self._device)

        # Limits are kept in radians internally.
        if self._rot_unit == "degree":
            self._dof_lower_limits = torch.deg2rad(self._dof_lower_limits)
            self._dof_upper_limits = torch.deg2rad(self._dof_upper_limits)

    def _parse_xml(self):
        """Walk the XML body tree and fill the per-body lists.

        Reads the ``angle`` attribute of the optional ``<compiler>`` element
        (defaulting to "degree") and then visits every ``<body>`` depth-first.
        """
        tree = ET.parse(self._file_path)
        xml_doc_root = tree.getroot()

        xml_world_body = xml_doc_root.find("worldbody")
        assert xml_world_body is not None, "worldbody not found"

        xml_body_root = xml_world_body.find("body")
        assert xml_body_root is not None, "body not found"

        # A missing <compiler> element is treated like a missing "angle"
        # attribute: both fall back to degrees.
        compiler_data = xml_doc_root.find("compiler")
        if compiler_data is None:
            self._rot_unit = "degree"
        else:
            self._rot_unit = compiler_data.attrib.get("angle", "degree")
        assert self._rot_unit in ["degree", "radian"], f"Invalid rotation unit: {self._rot_unit}"

        def _add_xml_body(xml_node, parent_index, body_index):
            """Register ``xml_node`` and its descendants; return the next free body index."""
            body_name = xml_node.attrib.get("name")

            pos_data = xml_node.attrib.get("pos", "0 0 0")
            pos = np.fromstring(pos_data, dtype=float, sep=" ")

            # The XML stores quaternions as (w, x, y, z); reorder to (x, y, z, w).
            rot_data = xml_node.attrib.get("quat", "1 0 0 0")
            rot = np.fromstring(rot_data, dtype=float, sep=" ")
            rot_w = rot[..., 0].copy()
            rot[..., 0:3] = rot[..., 1:]
            rot[..., 3] = rot_w

            if body_index == 0:
                # The root body never contributes dofs, whatever joints it declares.
                curr_joint = Joint(name=body_name, dof_dim=0, axis=None)
            else:
                curr_joints = xml_node.findall("joint")
                num_joints = len(curr_joints)

                if num_joints == 0:
                    curr_joint = Joint(name=body_name, dof_dim=0, axis=None)
                elif num_joints == 1:
                    # NOTE(review): "axis" and "range" are read without a
                    # default, so both attributes must be present on the joint.
                    _axis = np.fromstring(curr_joints[0].attrib.get("axis"), dtype=float, sep=" ")
                    axis = torch.from_numpy(_axis).to(self._device)
                    curr_joint = Joint(name=body_name, dof_dim=1, axis=axis)

                    _dof_limits = np.fromstring(curr_joints[0].attrib.get("range"), dtype=float, sep=" ")
                    self._dof_lower_limits.append(_dof_limits[0])
                    self._dof_upper_limits.append(_dof_limits[1])
                elif num_joints == 3:
                    # Three joints on one body are merged into a single 3-dof
                    # joint; the individual axes are not stored.
                    curr_joint = Joint(name=body_name, dof_dim=3, axis=None)
                    for joint in curr_joints:
                        _dof_limits = np.fromstring(joint.attrib.get("range"), dtype=float, sep=" ")
                        self._dof_lower_limits.append(_dof_limits[0])
                        self._dof_upper_limits.append(_dof_limits[1])
                else:
                    raise ValueError(f"Invalid number of joints: {num_joints} of body: {body_name}")

            self._body_names.append(body_name)
            self._parent_indices.append(parent_index)
            self._local_rotation.append(rot)
            self._local_translation.append(pos)
            self._joints.append(curr_joint)
            self._dof_size.append(curr_joint.dof_dim)

            curr_index = body_index
            body_index += 1
            for child in xml_node.findall("body"):
                body_index = _add_xml_body(child, curr_index, body_index)

            return body_index

        _add_xml_body(xml_body_root, -1, 0)

    def _set_dof_indices(self):
        """Assign each joint with dofs its start offset in the whole dof vector."""
        curr_dof_idx = 0
        for joint in self._joints:
            if joint.dof_dim > 0:
                joint.set_dof_idx(curr_dof_idx)
                curr_dof_idx += joint.dof_dim

    def dof_to_rot(self, dof):
        """Convert a whole dof vector to per-joint quaternions.

        Args:
            dof: Tensor of shape [..., num_dof].

        Returns:
            Tensor of shape [..., num_joint - 1, 4]; the root is excluded, so
            row ``j - 1`` belongs to joint ``j``.
        """
        rot_shape = list(dof.shape[:-1]) + [self.num_joint - 1, 4]
        joint_rot = torch.zeros(rot_shape, dtype=dof.dtype, device=dof.device)

        for j in range(1, self.num_joint):
            joint = self._joints[j]
            if joint.dof_idx == -1:
                # No dofs: identity quaternion.
                joint_rot[..., j - 1, -1] = 1.0
            else:
                joint_dof = dof[..., joint.dof_idx:joint.dof_idx + joint.dof_dim]
                joint_rot[..., j - 1, :] = joint.dof_to_rot(joint_dof)

        return joint_rot

    def rot_to_dof(self, rot):
        """Convert per-joint quaternions to a whole dof vector.

        Args:
            rot: Tensor of shape [..., num_joint - 1, 4] (root excluded).

        Returns:
            Tensor of shape [..., num_dof], clamped to the dof limits.
        """
        dof_shape = list(rot.shape[:-2]) + [self.num_dof]
        dof = torch.zeros(dof_shape, dtype=rot.dtype, device=rot.device)

        for j in range(1, self.num_joint):
            joint = self._joints[j]
            if joint.dof_dim == 0:
                continue
            joint_rot = rot[..., j - 1, :]
            dof[..., joint.dof_idx:joint.dof_idx + joint.dof_dim] = joint.rot_to_dof(joint_rot)

        dof = torch.clamp(dof, self._dof_lower_limits, self._dof_upper_limits)
        return dof

    def convert_local_rot_to_global(self, local_rot):
        """Accumulate per-joint rotations down the tree.

        Args:
            local_rot: Tensor of shape [..., num_joint, 4]; row 0 is the root
                rotation, row ``j`` the rotation of joint ``j`` relative to
                its parent.

        Returns:
            Tensor of the same shape where row ``j`` is the parent's
            accumulated rotation multiplied by ``local_rot[..., j, :]``.
            The rest-pose rotations parsed from the XML are not applied here.
        """
        global_rot = torch.zeros_like(local_rot)
        global_rot[..., 0, :] = local_rot[..., 0, :]

        # Parents always precede their children, so one forward pass suffices.
        for j in range(1, self.num_joint):
            parent_idx = self._parent_indices[j]
            parent_rot = global_rot[..., parent_idx, :]
            local_rot_j = local_rot[..., j, :]
            global_rot[..., j, :] = torch_utils.quat_mul(parent_rot, local_rot_j)

        return global_rot

    def forward_kinematics(self, root_pos, root_rot, dof_pos, fitted_shape=None):
        """Compute the pose of every body from the root pose and dofs.

        Args:
            root_pos: Root position; presumably shape [..., 3] (confirm
                against callers).
            root_rot: Root quaternion; presumably shape [..., 4].
            dof_pos: Tensor of shape [..., num_dof].
            fitted_shape: Optional per-body factor; ``fitted_shape[j]`` is
                multiplied onto body ``j``'s local translation.

        Returns:
            ``(body_pos, body_rot)``, each stacked along a new second-to-last
            axis of size num_joint.
        """
        joint_rot = self.dof_to_rot(dof_pos)

        body_pos = [None] * self.num_joint
        body_rot = [None] * self.num_joint
        body_pos[0] = root_pos
        body_rot[0] = root_rot

        for j in range(1, self.num_joint):
            j_rot = joint_rot[..., j - 1, :]
            if fitted_shape is None:
                local_trans = self._local_translation[j]
            else:
                local_trans = self._local_translation[j] * fitted_shape[j]
            local_rot = self._local_rotation[j]
            parent_idx = self._parent_indices[j]

            parent_pos = body_pos[parent_idx]
            parent_rot = body_rot[parent_idx]

            local_trans_broadcast = torch.broadcast_to(local_trans, parent_pos.shape)
            local_rot_broadcast = torch.broadcast_to(local_rot, parent_rot.shape)

            # Position: parent position plus the rotated local offset.
            world_trans = torch_utils.quat_rotate(parent_rot, local_trans_broadcast)
            curr_pos = parent_pos + world_trans

            # Rotation: parent * rest-pose local rotation * joint rotation.
            curr_rot = torch_utils.quat_mul(local_rot_broadcast, j_rot)
            curr_rot = torch_utils.quat_mul(parent_rot, curr_rot)

            body_pos[j] = curr_pos
            body_rot[j] = curr_rot

        body_pos = torch.stack(body_pos, dim=-2)
        body_rot = torch.stack(body_rot, dim=-2)

        return body_pos, body_rot

    def get_body_idx(self, body_name):
        """Return the index of ``body_name``; raises ValueError if unknown."""
        return self._body_names.index(body_name)

    @property
    def body_names(self):
        """Body names in index order."""
        return self._body_names

    @property
    def num_dof(self):
        """Total number of dofs over all joints."""
        return self._num_dof

    @property
    def num_joint(self):
        """Number of joints (one per body, root included)."""
        return len(self._joints)

    @property
    def joint_dof_idx(self):
        """Start dof index of every joint (-1 for joints without dofs)."""
        return [joint.dof_idx for joint in self._joints]

    @property
    def parent_indices(self):
        """Long tensor of parent body indices (-1 for the root)."""
        return self._parent_indices

    def get_parent_idx(self, idx):
        """Return the parent index of body ``idx`` (as a tensor element)."""
        return self._parent_indices[idx]

    def get_dof_limits(self):
        """Return ``(lower_limits, upper_limits)`` tensors in radians."""
        return self._dof_lower_limits, self._dof_upper_limits