Spaces:
Running on Zero
Running on Zero
File size: 5,952 Bytes
2dd4628 13116e0 2dd4628 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | import numpy as np
from typing import List, Tuple
def axis_point_to_plucker(axis: np.ndarray, point: np.ndarray) -> np.ndarray:
    """
    Convert an axis/point line description to Plucker coordinates.

    axis: (..., 3) line direction; it is normalised here (a 1e-8 term in the
        denominator guards against division by zero for a null axis).
    point: (..., 3) any point lying on the line.

    Returns an array whose last dimension is 6: the normalised direction
    followed by the moment. Note the moment is computed as
    cross(direction, point), not cross(point, direction).
    """
    assert axis.shape[-1] == 3
    assert point.shape[-1] == 3

    axis_length = np.linalg.norm(axis, axis=-1, keepdims=True)
    direction = axis / (axis_length + 1e-8)
    moment = np.cross(direction, point, axis=-1)
    return np.concatenate([direction, moment], axis=-1)
def plucker_to_axis_point(plucker: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split Plucker coordinates back into a unit axis and a point on the line.

    plucker: (..., 6) with the direction in the first three components and
        the moment in the last three.

    Returns (axis, point), each (..., 3). The axis is the re-normalised
    direction; the point is cross(moment, axis). For a moment that was built
    as cross(direction, point) with a unit direction, this is the point of
    the line closest to the origin.
    """
    assert plucker.shape[-1] == 6

    direction = plucker[..., :3]
    moment = plucker[..., 3:]
    direction_length = np.linalg.norm(direction, axis=-1, keepdims=True)
    axis = direction / (direction_length + 1e-8)
    return axis, np.cross(moment, axis, axis=-1)
def plucker_to_4x4_transform_matrix(plucker: np.ndarray, angle: float) -> np.ndarray:
    """
    Build the 4x4 rigid transform for a rotation of `angle` about a Plucker line.

    plucker: (6,) Plucker coordinates of the rotation axis.
    angle: rotation angle in radians (it is passed to np.sin / np.cos).

    Returns a (4, 4) homogeneous matrix.
    """
    assert plucker.shape == (6,)
    axis, point = plucker_to_axis_point(plucker)

    # Skew-symmetric cross-product matrix of the axis.
    ax, ay, az = axis
    skew = np.array([
        [0, -az, ay],
        [az, 0, -ax],
        [-ay, ax, 0],
    ])

    # Rodrigues' rotation formula.
    rotation = np.eye(3) + np.sin(angle) * skew + (1 - np.cos(angle)) * (skew @ skew)

    transform = np.eye(4)
    transform[:3, :3] = rotation
    # Points on the axis must stay fixed: R @ p + t == p  =>  t = p - R @ p.
    transform[:3, 3] = point - rotation @ point
    return transform
def transform_points(points: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply a 4x4 rigid transform (rotation and translation) to points.

    points: (..., 3)
    transform_matrix: (4, 4)

    Returns the transformed points, (..., 3).
    """
    rotation = transform_matrix[:3, :3]
    translation = transform_matrix[:3, 3]
    # Points are row vectors, hence the transposed rotation on the right.
    rotated = points @ rotation.T
    return rotated + translation
def transform_direction(direction: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply only the rotational part of a 4x4 transform to direction vectors.

    direction: (..., 3)
    transform_matrix: (4, 4)

    Returns the rotated directions, (..., 3); the translation column is ignored.
    """
    rotation = transform_matrix[:3, :3]
    # Directions are row vectors, hence the transposed rotation on the right.
    return direction @ rotation.T
def transform_plucker(plucker: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Transform a Plucker line by a 4x4 transform matrix.

    The line is converted to axis/point form, the axis is rotated, the point
    is rotated and translated, and the result is converted back to Plucker
    coordinates. The input is cast to float32 and the output is returned as
    float32.
    """
    line = np.asarray(plucker, dtype=np.float32)
    axis, point = plucker_to_axis_point(line)
    moved_line = axis_point_to_plucker(
        transform_direction(axis, transform_matrix),
        transform_points(point, transform_matrix),
    )
    return moved_line.astype(np.float32, copy=False)
def get_subtree_part_ids(motion_hierarchy: List[Tuple[int, int]], part_id: int) -> List[int]:
    """
    Get the ids of `part_id` and all of its descendants.

    motion_hierarchy: iterable of (parent_id, child_id) edges. Part ids must
        be hashable.
    part_id: id of the part at the top of the requested subtree.

    Returns the ids in depth-first pre-order, starting with `part_id`, with
    children visited in the order their edges appear in `motion_hierarchy`.
    Each id is listed once, even if the hierarchy contains a cycle or a part
    reachable through several parents.
    """
    # Build the adjacency once instead of rescanning every edge per node.
    children_of = {}
    for parent_id, child_id in motion_hierarchy:
        children_of.setdefault(parent_id, []).append(child_id)

    subtree_part_ids = []
    visited = set()
    pending = [part_id]
    while pending:
        current_id = pending.pop()
        if current_id in visited:
            # Already emitted: guards against cycles and shared children.
            continue
        visited.add(current_id)
        subtree_part_ids.append(current_id)
        # Push in reverse so the first-listed child is popped (visited) first.
        pending.extend(reversed(children_of.get(current_id, ())))
    return subtree_part_ids
def get_part_order_from_root(motion_hierarchy: List[Tuple[int, int]]) -> List[int]:
    """
    Depth-first search to get the part order from the root.

    motion_hierarchy: iterable of (parent_id, child_id) edges. Part ids must
        be hashable.

    Returns the part ids in depth-first pre-order, so every part appears
    after its parent. Children are visited in the order their edges appear.
    A root is a part that is never a child; if there are several, each one
    is traversed, in order of first appearance as a parent. An empty
    hierarchy yields an empty list.

    Raises KeyError if the hierarchy is non-empty but has no root (every
    parent is also somebody's child, i.e. the edges form a cycle).
    """
    children_of = {}
    child_ids = set()
    for parent_id, child_id in motion_hierarchy:
        children_of.setdefault(parent_id, []).append(child_id)
        child_ids.add(child_id)

    if not children_of:
        return []

    # dicts keep insertion order, so the roots come out deterministically
    # (unlike popping an arbitrary element from a set).
    root_part_ids = [pid for pid in children_of if pid not in child_ids]
    if not root_part_ids:
        raise KeyError("motion_hierarchy has no root part: every parent is also a child")

    part_order = []
    visited = set()
    pending = root_part_ids[::-1]
    while pending:
        part_id = pending.pop()
        if part_id in visited:
            continue
        visited.add(part_id)
        part_order.append(part_id)
        # Push in reverse so the first-listed child is popped (visited) first.
        pending.extend(reversed(children_of.get(part_id, ())))
    return part_order
def compute_part_transforms(
    unique_part_ids,
    motion_hierarchy,
    is_part_revolute,
    is_part_prismatic,
    revolute_plucker,
    revolute_range,
    prismatic_axis,
    prismatic_range,
    articulation_state,
):
    """
    Compute the 4x4 transformation matrix for each part at a given articulation state.

    unique_part_ids: iterable of part ids that must appear in the result.
    motion_hierarchy: sequence of (parent_id, child_id) edges.
    is_part_revolute / is_part_prismatic: indexed by part id; truthy when the
        part's joint is of that type. Revolute takes precedence; a part that
        is neither is treated as fixed to its parent.
    revolute_plucker: indexed by part id; (6,) Plucker line of the hinge.
    revolute_range: indexed by part id; (low, high) angle limits.
    prismatic_axis: indexed by part id; translation direction (used as
        given, it is not normalised here).
    prismatic_range: indexed by part id; (low, high) displacement limits.
    articulation_state: either one scalar applied to every joint, or a
        container indexed by part id. The state linearly interpolates each
        joint between its low and high limit.

    Returns a dictionary mapping part_id to its cumulative transformation
    matrix, covering `unique_part_ids` plus every id in `motion_hierarchy`.
    The transformation represents how to transform each part from its rest
    pose to the articulated pose. Joint axes are taken to be expressed in
    the rest pose (they are used as given, without being moved by ancestor
    joints), so a part's transform is the product of the joint transforms
    along its chain, root first: J_root @ ... @ J_parent @ J_part.
    """
    if len(motion_hierarchy) == 0:
        return {pid: np.eye(4) for pid in unique_part_ids}

    # Collect all relevant part IDs from motion hierarchy and unique_part_ids
    all_part_ids = set(unique_part_ids)
    for parent, child in motion_hierarchy:
        all_part_ids.add(parent)
        all_part_ids.add(child)
    transforms = {pid: np.eye(4) for pid in all_part_ids}

    # Loop-invariant: does one state value drive every joint?
    shared_state = (
        np.isscalar(articulation_state)
        or np.asarray(articulation_state).ndim == 0
    )

    # Process parts root-first so ancestors' joints are accumulated before
    # their descendants' joints.
    for pid in get_part_order_from_root(motion_hierarchy):
        part_articulation_state = (
            articulation_state if shared_state else articulation_state[pid]
        )

        # Compute transformation for this part's joint
        if is_part_revolute[pid]:
            low_limit, high_limit = revolute_range[pid]
            angle = low_limit + part_articulation_state * (high_limit - low_limit)
            joint_transform = plucker_to_4x4_transform_matrix(revolute_plucker[pid], angle)
        elif is_part_prismatic[pid]:
            low_limit, high_limit = prismatic_range[pid]
            displacement = low_limit + part_articulation_state * (high_limit - low_limit)
            joint_transform = np.eye(4)
            joint_transform[:3, 3] = displacement * np.asarray(prismatic_axis[pid])
        else:
            # Fixed part: its joint is the identity, nothing to accumulate.
            continue

        # Apply the joint to this part and all of its descendants. The joint
        # acts in rest-pose coordinates, so it has to be applied BEFORE the
        # ancestor motion already stored in the transform: right-multiply.
        # Left-multiplying would rotate/slide the part about its rest-pose
        # axis after the parent had already carried that axis elsewhere.
        for affected_pid in get_subtree_part_ids(motion_hierarchy, pid):
            transforms[affected_pid] = transforms[affected_pid] @ joint_transform

    return transforms
|