Spaces:
Running on Zero
Running on Zero
| import numpy as np | |
| from typing import List, Tuple | |
def axis_point_to_plucker(axis: np.ndarray, point: np.ndarray) -> np.ndarray:
    """
    Build Plucker coordinates ``[l, m]`` for the line through ``point`` along ``axis``.

    ``l`` is ``axis`` divided by its length (plus a 1e-8 guard against a zero
    axis) and the moment is ``m = l x point``. Note the operand order: this
    module uses ``l x p``, and ``plucker_to_axis_point`` inverts exactly this
    convention.

    Args:
        axis: direction(s) of the line, trailing dimension 3.
        point: any point(s) on the line, trailing dimension 3.

    Returns:
        Array whose trailing dimension is 6: ``[l, m]`` concatenated.
    """
    assert axis.shape[-1] == 3
    assert point.shape[-1] == 3
    length = np.linalg.norm(axis, axis=-1, keepdims=True)
    direction = axis / (length + 1e-8)
    moment = np.cross(direction, point, axis=-1)
    return np.concatenate((direction, moment), axis=-1)
def plucker_to_axis_point(plucker: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert plucker coordinates to axis-point coordinates.

    ``plucker`` has trailing dimension 6 and holds ``[l, m]`` with the
    convention ``m = l x p`` (the one produced by ``axis_point_to_plucker``).
    Plucker coordinates are homogeneous: ``(s*l, s*m)`` describes the same line
    for any ``s != 0``, so ``l`` is not required to be unit length.

    Returns:
        axis: ``l`` divided by its length (a zero ``l`` yields a zero axis).
        point: the point on the line closest to the origin,
            ``(m x l) / |l|^2``; a zero ``l`` yields a zero point.
    """
    assert plucker.shape[-1] == 6
    direction, moment = plucker[..., :3], plucker[..., 3:]
    length = np.linalg.norm(direction, axis=-1, keepdims=True)
    axis = direction / (length + 1e-8)
    # Divide by |l|^2 rather than using the unit axis alone: with m = l x p,
    # (m x l) = |l|^2 * p_perp, so dividing only by |l| would leave the point
    # scaled by |l| (off the line) whenever l is not unit length. For unit l
    # this is identical to m x axis.
    point = np.cross(moment, direction, axis=-1) / (length ** 2 + 1e-8)
    return axis, point
def plucker_to_4x4_transform_matrix(plucker: np.ndarray, angle: float) -> np.ndarray:
    """
    Convert plucker coordinates to a 4x4 transformation matrix.

    The result rotates by ``angle`` (radians, as consumed by ``np.sin`` /
    ``np.cos``) about the line described by ``plucker``. The line is decoded
    with ``plucker_to_axis_point`` into a unit axis and a point ``p`` on it.
    The rotation uses Rodrigues' formula ``R = I + sin(a) K + (1 - cos(a)) K^2``,
    where ``K`` is the cross-product matrix of the axis. The translation
    ``p - R p`` makes the map ``x -> R (x - p) + p``, so points on the line stay
    fixed.

    Args:
        plucker: Plucker coordinates of the rotation axis; any array-like of
            exactly 6 numbers (a numpy array, list or tuple).
        angle: rotation angle in radians.

    Returns:
        A ``(4, 4)`` homogeneous transform.
    """
    # Accept lists/tuples too; previously ``plucker.shape`` raised AttributeError.
    plucker = np.asarray(plucker)
    assert plucker.shape == (6,)
    axis, point = plucker_to_axis_point(plucker)
    ax, ay, az = axis
    # Cross-product (skew-symmetric) matrix: K @ v == axis x v.
    K = np.array([
        [0.0, -az, ay],
        [az, 0.0, -ax],
        [-ay, ax, 0.0],
    ])
    R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = point - R @ point
    return T
def transform_points(points: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply a 4x4 homogeneous transform to points.

    Each row vector ``x`` becomes ``R x + t``, where ``R`` is the upper-left
    3x3 block and ``t`` the first three entries of the last column. The bottom
    row of the matrix is ignored.

    points: (..., 3)
    transform_matrix: (4, 4)
    """
    rotation = transform_matrix[:3, :3]
    translation = transform_matrix[:3, 3]
    # Row-vector convention: x @ R.T == (R @ x.T).T
    return np.matmul(points, rotation.T) + translation
def transform_direction(direction: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply only the linear (3x3) part of a 4x4 transform to direction vectors.

    Directions are unaffected by translation, so the last column of the matrix
    is ignored. The result is not re-normalised.

    direction: (..., 3)
    transform_matrix: (4, 4)
    """
    linear_part = transform_matrix[:3, :3]
    return np.matmul(direction, linear_part.T)
def transform_plucker(plucker: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Move a Plucker line by a 4x4 transform and return it as float32.

    The line is decoded into (axis, point) after casting the input to float32.
    The axis is mapped with ``transform_direction`` and the point with
    ``transform_points``, and the pair is re-encoded with
    ``axis_point_to_plucker``, which re-normalises the axis.
    """
    line32 = np.asarray(plucker, dtype=np.float32)
    old_axis, old_point = plucker_to_axis_point(line32)
    new_axis = transform_direction(old_axis, transform_matrix)
    new_point = transform_points(old_point, transform_matrix)
    encoded = axis_point_to_plucker(new_axis, new_point)
    return encoded.astype(np.float32, copy=False)
def get_subtree_part_ids(motion_hierarchy: List[Tuple[int, int]], part_id: int) -> List[int]:
    """
    Get the subtree part ids for a given part id.

    Returns ``part_id`` followed by all of its descendants in depth-first
    pre-order, with siblings in the order their ``(parent_id, child_id)`` edges
    appear in ``motion_hierarchy``. Each id is listed once. A child reachable
    through several edges is not duplicated, and a cyclic hierarchy terminates
    instead of recursing forever.
    """
    # Build the adjacency list once instead of rescanning every edge per node.
    children = {}
    for parent_id, child_id in motion_hierarchy:
        children.setdefault(parent_id, []).append(child_id)

    subtree_part_ids = []
    visited = set()
    stack = [part_id]
    while stack:
        current = stack.pop()
        if current in visited:
            continue
        visited.add(current)
        subtree_part_ids.append(current)
        # Push children reversed so they pop in edge order, reproducing a
        # recursive pre-order traversal.
        stack.extend(reversed(children.get(current, [])))
    return subtree_part_ids
def get_part_order_from_root(motion_hierarchy: List[Tuple[int, int]]) -> List[int]:
    """
    Depth-first search to get the part order from the root.

    Returns part ids in DFS pre-order: every parent precedes its children, and
    siblings follow the order of their ``(parent_id, child_id)`` edges. A root
    is any parent id that never appears as a child. If there are several roots
    (a forest), every tree is traversed, in order of each root's first
    appearance. Each part is listed once, even if edges are duplicated or a
    child has more than one parent. An empty hierarchy yields ``[]``.

    Raises:
        ValueError: if the hierarchy is non-empty but has no root, which
            happens only when it contains a cycle.
    """
    # Adjacency list in edge order; dict insertion order also records the
    # order in which parent ids first appear.
    children = {}
    child_ids = set()
    for parent_id, child_id in motion_hierarchy:
        children.setdefault(parent_id, []).append(child_id)
        child_ids.add(child_id)
    if not children:
        return []

    roots = [pid for pid in children if pid not in child_ids]
    if not roots:
        raise ValueError("motion_hierarchy has no root part; it must be acyclic")

    part_order = []
    visited = set()
    for root_part_id in roots:
        stack = [root_part_id]
        while stack:
            part_id = stack.pop()
            if part_id in visited:
                continue
            visited.add(part_id)
            part_order.append(part_id)
            # Reversed push so children pop in edge order (pre-order DFS).
            stack.extend(reversed(children.get(part_id, [])))
    return part_order
def compute_part_transforms(
    unique_part_ids,
    motion_hierarchy,
    is_part_revolute,
    is_part_prismatic,
    revolute_plucker,
    revolute_range,
    prismatic_axis,
    prismatic_range,
    articulation_state
):
    """
    Compute the 4x4 transformation matrix for each part at a given articulation state.

    Returns a dictionary mapping part_id to its cumulative transformation matrix.
    The transformation represents how to transform each part from its rest pose
    to the articulated pose.

    Joint parameters (``revolute_plucker[pid]``, ``prismatic_axis[pid]``) are
    used as given, i.e. all in the same rest-pose frame. A part's cumulative
    transform must therefore apply its own joint first (in the rest pose) and
    then each ancestor's joint, from parent up to root:
    ``T_part = J_root @ ... @ J_parent @ J_part``.

    Args:
        unique_part_ids: part ids that must appear in the output.
        motion_hierarchy: sequence of ``(parent_id, child_id)`` joints.
        is_part_revolute, is_part_prismatic: per-part joint-type flags,
            indexed by part id.
        revolute_plucker: per-part Plucker coordinates of the rotation axis.
        revolute_range: per-part ``(low, high)`` angle limits in radians.
        prismatic_axis: per-part translation direction (used unnormalised).
        prismatic_range: per-part ``(low, high)`` displacement limits.
        articulation_state: either a scalar shared by every joint or a
            per-part sequence indexed by part id. Each joint value is
            ``low + state * (high - low)``; presumably ``state`` is in [0, 1].

    Returns:
        ``dict`` mapping part id to a ``(4, 4)`` transform. If
        ``motion_hierarchy`` is empty, every part maps to the identity.
    """
    if len(motion_hierarchy) == 0:
        return {pid: np.eye(4) for pid in unique_part_ids}
    # Collect all relevant part IDs from motion hierarchy and unique_part_ids
    all_part_ids = set(unique_part_ids)
    for parent, child in motion_hierarchy:
        all_part_ids.add(parent)
        all_part_ids.add(child)
    transforms = {pid: np.eye(4) for pid in all_part_ids}

    # A scalar (or 0-d array) state drives every joint; otherwise index per part.
    shared_state = np.isscalar(articulation_state) or np.asarray(articulation_state).ndim == 0

    # Pre-order: every ancestor is processed before its descendants.
    part_order = get_part_order_from_root(motion_hierarchy)
    for pid in part_order:
        part_articulation_state = articulation_state if shared_state else articulation_state[pid]

        # Compute transformation for this part's joint (in the rest-pose frame).
        joint_transform = np.eye(4)
        if is_part_revolute[pid]:
            low_limit, high_limit = revolute_range[pid]
            angle = low_limit + part_articulation_state * (high_limit - low_limit)
            joint_transform = plucker_to_4x4_transform_matrix(revolute_plucker[pid], angle)
        elif is_part_prismatic[pid]:
            low_limit, high_limit = prismatic_range[pid]
            displacement = low_limit + part_articulation_state * (high_limit - low_limit)
            paxis = np.asarray(prismatic_axis[pid], dtype=float)
            joint_transform[:3, 3] = displacement * paxis

        # Right-multiply: ancestors' joints are already accumulated on the left,
        # so this joint acts first, in the rest pose, on every descendant.
        # Left-multiplying (as before) applied a child's rest-frame joint AFTER
        # its parent had moved it, e.g. sliding a drawer along its un-rotated axis.
        # dict.fromkeys dedupes so a joint is never applied twice to one part.
        for affected_pid in dict.fromkeys(get_subtree_part_ids(motion_hierarchy, pid)):
            if affected_pid in transforms:
                transforms[affected_pid] = transforms[affected_pid] @ joint_transform
    return transforms