| import numpy as np |
| import torch |
| import zarr |
| from zarr.storage import DirectoryStore |
| from zarr import LRUStoreCache |
|
|
| import utils.pytorch3d_transforms as pytorch3d_transforms |
|
|
|
|
def to_tensor(x):
    """Convert ``x`` to a ``torch.Tensor``.

    - A ``torch.Tensor`` is returned as-is (the very same object).
    - A ``np.ndarray`` is wrapped with ``torch.from_numpy``, so the result
      shares memory with the array.
    - Anything else is passed to ``torch.as_tensor``.
    """
    if not isinstance(x, torch.Tensor):
        converter = torch.from_numpy if isinstance(x, np.ndarray) else torch.as_tensor
        x = converter(x)
    return x
|
|
|
|
def read_zarr_with_cache(fname, mem_gb=16):
    """Open the zarr group at ``fname`` read-only behind an LRU store cache.

    Args:
        fname: path handed to ``DirectoryStore``.
        mem_gb: cache capacity in GiB. It is converted to bytes as
            ``mem_gb * 2**30`` and passed as ``max_size`` to ``LRUStoreCache``.

    Returns:
        The result of ``zarr.open_group(..., mode="r")`` on the cached store.
    """
    cache_bytes = mem_gb * (1 << 30)
    return zarr.open_group(
        LRUStoreCache(DirectoryStore(fname), max_size=cache_bytes),
        mode="r",
    )
|
|
|
|
def to_relative_action(actions, anchor_action, qform='xyzw'):
    """
    Compute delta actions where the first delta is relative to anchor,
    and subsequent deltas are relative to the previous timestep.

    Each 8-dim action is laid out as ``[pos(3), quat(4), gripper(1)]``
    (indices 0:3, 3:7 and 7). For each timestep ``t``, with the anchor
    acting as timestep ``-1``:

    * position delta: ``pos_t - pos_{t-1}``
    * orientation delta:
      ``quaternion_multiply(q_t, quaternion_invert(q_{t-1}))``. It is
      evaluated in (w, x, y, z) order and returned in ``qform`` order.
    * gripper: copied from ``actions`` unchanged (it is not a delta).

    Args:
        actions: (..., N, 8) — future trajectory
        anchor_action: (..., 1, 8) — current pose to treat as timestep -1.
            A tensor of shape (..., 8) without the time axis is also
            accepted and treated as (..., 1, 8).
        qform: 'xyzw' or 'wxyz' — quaternion format of inputs and output

    Returns:
        delta_actions: (..., N, 8)

    Raises:
        AssertionError: if the last dim of ``actions`` or ``anchor_action``
            is not 8.
        ValueError: if ``qform`` is not 'xyzw' or 'wxyz'.
    """
    assert actions.shape[-1] == 8
    # Validate before doing any tensor work.
    if qform not in ('xyzw', 'wxyz'):
        raise ValueError("Invalid quaternion format")

    # Allow an anchor given without the singleton time axis.
    if anchor_action.dim() == actions.dim() - 1:
        anchor_action = anchor_action.unsqueeze(-2)
    assert anchor_action.shape[-1] == 8, "anchor_action must have last dim 8"

    # prev[..., t, :] is the pose at t-1: the anchor followed by all but the
    # last action. Slicing after the concat keeps shapes valid when N == 0.
    prev = torch.cat([anchor_action, actions], -2)[..., :-1, :]

    rel_pos = actions[..., :3] - prev[..., :3]

    # Both branches feed the quaternion helpers (w, x, y, z) order.
    if qform == 'xyzw':
        cur_q = actions[..., [6, 3, 4, 5]]
        prev_q = prev[..., [6, 3, 4, 5]]
    else:
        cur_q = actions[..., 3:7]
        prev_q = prev[..., 3:7]

    rel_orn = pytorch3d_transforms.quaternion_multiply(
        cur_q, pytorch3d_transforms.quaternion_invert(prev_q)
    )
    if qform == 'xyzw':
        # Convert the (w, x, y, z) result back to (x, y, z, w).
        rel_orn = rel_orn[..., [1, 2, 3, 0]]

    # The gripper stays absolute.
    gripper = actions[..., -1:]

    return torch.cat([rel_pos, rel_orn, gripper], -1)
|
|