rayli's picture
Clean unused demo logic
13116e0 verified
Raw
History Blame Contribute Delete
5.95 kB
import numpy as np
from typing import List, Tuple
def axis_point_to_plucker(axis: np.ndarray, point: np.ndarray) -> np.ndarray:
    """
    Encode a line given by a direction and a point on it as Plucker coordinates.

    The result is ``[l, m]`` concatenated along the last axis, where ``l`` is
    ``axis`` divided by its length (plus 1e-8 to avoid division by zero) and
    ``m = l x point``.

    Note: this file uses ``m = l x p``, the opposite sign of the common
    ``p x l`` convention. ``plucker_to_axis_point`` undoes it with ``m x l``.

    axis, point: arrays whose last axis has size 3 and whose leading
        dimensions match.
    Returns an array whose last axis has size 6.
    """
    assert axis.shape[-1] == 3
    assert point.shape[-1] == 3
    length = np.linalg.norm(axis, axis=-1, keepdims=True)
    direction = axis / (length + 1e-8)
    moment = np.cross(direction, point, axis=-1)
    return np.concatenate([direction, moment], axis=-1)
def plucker_to_axis_point(plucker: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decode Plucker coordinates ``[l, m]`` into a unit direction and a point.

    This is the inverse of ``axis_point_to_plucker``, which builds ``m = l x p``.
    The point returned is ``m x l_hat``. For a unit ``l`` this equals
    ``p - l (l . p)``, the point on the line closest to the origin.

    NOTE(review): ``m`` is not divided by ``|l|``. The recovered point is
    therefore exact only when ``l`` is (near) unit length, or when ``m`` was
    computed with the normalized direction. Confirm this matches how callers
    produce their lines.

    plucker: array whose last axis has size 6.
    Returns ``(axis, point)``, each with last axis of size 3.
    """
    assert plucker.shape[-1] == 6
    direction = plucker[..., :3]
    moment = plucker[..., 3:]
    unit_direction = direction / (np.linalg.norm(direction, axis=-1, keepdims=True) + 1e-8)
    closest_point = np.cross(moment, unit_direction, axis=-1)
    return unit_direction, closest_point
def plucker_to_4x4_transform_matrix(plucker: np.ndarray, angle: float) -> np.ndarray:
    """
    Build the homogeneous transform for a rotation of ``angle`` radians about a
    Plucker line.

    The rotation comes from Rodrigues' formula,
    ``R = I + sin(angle) K + (1 - cos(angle)) K @ K``, where ``K`` is the
    cross-product (skew-symmetric) matrix of the line's unit direction.

    The translation ``p - R p`` keeps the point ``p`` recovered from the line
    fixed. The motion is therefore a rotation about the line itself, not about
    the origin.

    plucker: array of shape (6,).
    Returns a (4, 4) array.
    """
    assert plucker.shape == (6,)
    direction, pivot = plucker_to_axis_point(plucker)
    ax, ay, az = direction

    # K @ v == direction x v
    skew = np.array([
        [0, -az, ay],
        [az, 0, -ax],
        [-ay, ax, 0],
    ])
    sin_a = np.sin(angle)
    cos_a = np.cos(angle)
    rotation = np.eye(3) + sin_a * skew + (1 - cos_a) * (skew @ skew)

    transform = np.eye(4)
    transform[:3, :3] = rotation
    # Translate so that the pivot point on the line maps onto itself.
    transform[:3, 3] = pivot - rotation @ pivot
    return transform
def transform_points(points: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply a 4x4 homogeneous transform to points.

    Each point is multiplied by the upper-left 3x3 block and then offset by
    the translation column.

    points: (..., 3)
    transform_matrix: (4, 4)
    Returns an array with the same shape as ``points``.
    """
    linear_part = transform_matrix[:3, :3]
    offset = transform_matrix[:3, 3]
    # Row-vector form: p' = p @ A^T + t  (equivalent to A @ p + t per point)
    return points @ linear_part.T + offset
def transform_direction(direction: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Apply only the linear (upper-left 3x3) part of a 4x4 transform to
    direction vectors.

    The translation column is ignored, because directions are not displaced
    by translations. The result is not re-normalized.

    direction: (..., 3)
    transform_matrix: (4, 4)
    """
    linear_part = transform_matrix[:3, :3]
    return direction @ linear_part.T
def transform_plucker(plucker: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray:
    """
    Move a Plucker line by a 4x4 transform.

    The line is decoded into a (direction, point) pair. The direction goes
    through the 3x3 linear block only, while the point gets the full
    transform. The pair is then re-encoded.

    The input is converted to float32 first, and the result is returned as
    float32.
    """
    line = np.asarray(plucker, dtype=np.float32)
    direction, anchor = plucker_to_axis_point(line)
    moved_direction = transform_direction(direction, transform_matrix)
    moved_anchor = transform_points(anchor, transform_matrix)
    moved_line = axis_point_to_plucker(moved_direction, moved_anchor)
    return moved_line.astype(np.float32, copy=False)
def get_subtree_part_ids(motion_hierarchy: List[Tuple[int, int]], part_id: int) -> List[int]:
    """
    Return ``part_id`` followed by all of its descendants, in depth-first
    pre-order.

    ``motion_hierarchy`` is a list of ``(parent_id, child_id)`` edges.
    Children are visited in the order their edges appear.

    Each part is listed at most once. A part reachable through several edges,
    or a cyclic hierarchy, therefore neither produces duplicate entries nor
    recurses forever.

    A ``part_id`` with no children (or absent from the hierarchy) yields
    ``[part_id]``.
    """
    # Build the adjacency once instead of rescanning every edge per node.
    children = {}
    for parent_id, child_id in motion_hierarchy:
        children.setdefault(parent_id, []).append(child_id)

    subtree_part_ids = []
    visited = set()
    stack = [part_id]
    while stack:
        current = stack.pop()
        if current in visited:
            continue
        visited.add(current)
        subtree_part_ids.append(current)
        # Push children reversed so the first-listed child is popped first,
        # matching recursive pre-order.
        stack.extend(reversed(children.get(current, [])))
    return subtree_part_ids
def get_part_order_from_root(motion_hierarchy: List[Tuple[int, int]]) -> List[int]:
    """
    Depth-first (pre-order) traversal of the motion hierarchy from its root(s).

    ``motion_hierarchy`` is a list of ``(parent_id, child_id)`` edges. A root
    is a part that appears as a parent but never as a child.

    Every root is traversed, in the order it first appears as a parent. Within
    each tree, children are visited in edge order. Each part is listed once,
    so a part with several incoming edges does not break the traversal.

    Returns an empty list for an empty hierarchy.

    Raises:
        ValueError: if the hierarchy is non-empty but no part is free of a
            parent, i.e. the hierarchy contains a cycle through every part.
    """
    children = {}
    for parent_id, child_id in motion_hierarchy:
        children.setdefault(parent_id, []).append(child_id)
    child_ids = {child_id for _, child_id in motion_hierarchy}

    # Dict keys preserve first-appearance order, giving a deterministic root order.
    # Using membership (not set.remove) tolerates a child listed in several edges.
    root_ids = [pid for pid in children if pid not in child_ids]
    if motion_hierarchy and not root_ids:
        raise ValueError("motion_hierarchy has no root part; it must not contain cycles")

    part_order = []
    visited = set()
    stack = list(reversed(root_ids))
    while stack:
        part_id = stack.pop()
        if part_id in visited:
            continue
        visited.add(part_id)
        part_order.append(part_id)
        # Reverse so the first-listed child is processed first (recursive pre-order).
        stack.extend(reversed(children.get(part_id, [])))
    return part_order
def compute_part_transforms(
    unique_part_ids,
    motion_hierarchy,
    is_part_revolute,
    is_part_prismatic,
    revolute_plucker,
    revolute_range,
    prismatic_axis,
    prismatic_range,
    articulation_state
):
    """
    Compute the 4x4 transformation matrix for each part at a given articulation state.

    Returns a dictionary mapping part_id to its cumulative transformation
    matrix. The transformation represents how to transform each part from its
    rest pose to the articulated pose.

    Joint parameters are treated as expressed in the rest-pose frame
    (NOTE(review): confirm with callers). A part is first moved by its own
    joint, then carried along by each ancestor's joint. For a chain
    root -> A -> B the result for B is ``J_A @ J_B``.

    Args:
        unique_part_ids: part ids that get an entry even if they do not appear
            in ``motion_hierarchy``.
        motion_hierarchy: list of ``(parent_id, child_id)`` edges. If empty,
            every id in ``unique_part_ids`` maps to the identity.
        is_part_revolute, is_part_prismatic: indexed by part id; truthy when
            that part has a joint of the given type (revolute is checked first).
        revolute_plucker: indexed by part id; 6-vector Plucker line of the
            rotation axis.
        revolute_range, prismatic_range: indexed by part id; ``(low, high)``
            joint limits.
        prismatic_axis: indexed by part id; translation direction, used as
            given (not normalized here).
        articulation_state: either a scalar applied to every joint, or a
            per-part sequence / dict indexed by part id. A value ``s`` maps
            linearly to ``low + s * (high - low)``: 0 gives ``low``, 1 gives
            ``high``.

    Returns:
        dict mapping part_id to a (4, 4) transform (rest pose -> articulated pose).
    """
    if len(motion_hierarchy) == 0:
        return {pid: np.eye(4) for pid in unique_part_ids}

    # Every part mentioned anywhere starts at the identity.
    all_part_ids = set(unique_part_ids)
    for parent, child in motion_hierarchy:
        all_part_ids.add(parent)
        all_part_ids.add(child)
    transforms = {pid: np.eye(4) for pid in all_part_ids}

    # A 0-d value drives every joint; anything else is indexed by part id.
    # Dicts are checked explicitly because np.ndim(dict) == 0 would wrongly
    # classify them as scalars.
    shared_state = (
        not isinstance(articulation_state, dict)
        and np.ndim(articulation_state) == 0
    )

    for pid in get_part_order_from_root(motion_hierarchy):
        is_revolute = is_part_revolute[pid]
        if not is_revolute and not is_part_prismatic[pid]:
            # Fixed joint: identity, nothing to compose.
            continue

        part_state = articulation_state if shared_state else articulation_state[pid]
        if is_revolute:
            low_limit, high_limit = revolute_range[pid]
            angle = low_limit + part_state * (high_limit - low_limit)
            joint_transform = plucker_to_4x4_transform_matrix(revolute_plucker[pid], angle)
        else:
            low_limit, high_limit = prismatic_range[pid]
            displacement = low_limit + part_state * (high_limit - low_limit)
            joint_transform = np.eye(4)
            joint_transform[:3, 3] = displacement * prismatic_axis[pid]

        # Parts are visited parent-first. Right-multiplying makes each deeper
        # (rest-pose) joint act before its ancestors' joints:
        # T_B = J_A @ J_B. Left-multiplying would yield J_B @ J_A, which
        # applies B's joint about its *unmoved* rest-pose axis after A has
        # already moved B.
        for affected_pid in get_subtree_part_ids(motion_hierarchy, pid):
            if affected_pid in transforms:
                transforms[affected_pid] = transforms[affected_pid] @ joint_transform
    return transforms