| | from typing import Sequence, Optional |
| | import torch |
| | from torch import nn |
| | from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin |
| |
|
| |
|
| | def get_intersection_slice_mask( |
| | shape: tuple, |
| | dim_slices: Sequence[slice], |
| | device: Optional[torch.device]=None |
| | ): |
| | assert(len(shape) == len(dim_slices)) |
| | mask = torch.zeros(size=shape, dtype=torch.bool, device=device) |
| | mask[dim_slices] = True |
| | return mask |
| |
|
| |
|
| | def get_union_slice_mask( |
| | shape: tuple, |
| | dim_slices: Sequence[slice], |
| | device: Optional[torch.device]=None |
| | ): |
| | assert(len(shape) == len(dim_slices)) |
| | mask = torch.zeros(size=shape, dtype=torch.bool, device=device) |
| | for i in range(len(dim_slices)): |
| | this_slices = [slice(None)] * len(shape) |
| | this_slices[i] = dim_slices[i] |
| | mask[this_slices] = True |
| | return mask |
| |
|
| |
|
| | class DummyMaskGenerator(ModuleAttrMixin): |
| | def __init__(self): |
| | super().__init__() |
| | |
| | @torch.no_grad() |
| | def forward(self, shape): |
| | device = self.device |
| | mask = torch.ones(size=shape, dtype=torch.bool, device=device) |
| | return mask |
| |
|
| |
|
| | class LowdimMaskGenerator(ModuleAttrMixin): |
| | def __init__(self, |
| | action_dim, obs_dim, |
| | |
| | max_n_obs_steps=2, |
| | fix_obs_steps=True, |
| | |
| | action_visible=False |
| | ): |
| | super().__init__() |
| | self.action_dim = action_dim |
| | self.obs_dim = obs_dim |
| | self.max_n_obs_steps = max_n_obs_steps |
| | self.fix_obs_steps = fix_obs_steps |
| | self.action_visible = action_visible |
| |
|
| | @torch.no_grad() |
| | def forward(self, shape, seed=None): |
| | device = self.device |
| | B, T, D = shape |
| | assert D == (self.action_dim + self.obs_dim) |
| |
|
| | |
| | rng = torch.Generator(device=device) |
| | if seed is not None: |
| | rng = rng.manual_seed(seed) |
| |
|
| | |
| | dim_mask = torch.zeros(size=shape, |
| | dtype=torch.bool, device=device) |
| | is_action_dim = dim_mask.clone() |
| | is_action_dim[...,:self.action_dim] = True |
| | is_obs_dim = ~is_action_dim |
| |
|
| | |
| | if self.fix_obs_steps: |
| | obs_steps = torch.full((B,), |
| | fill_value=self.max_n_obs_steps, device=device) |
| | else: |
| | obs_steps = torch.randint( |
| | low=1, high=self.max_n_obs_steps+1, |
| | size=(B,), generator=rng, device=device) |
| | |
| | steps = torch.arange(0, T, device=device).reshape(1,T).expand(B,T) |
| | obs_mask = (steps.T < obs_steps).T.reshape(B,T,1).expand(B,T,D) |
| | obs_mask = obs_mask & is_obs_dim |
| |
|
| | |
| | if self.action_visible: |
| | action_steps = torch.maximum( |
| | obs_steps - 1, |
| | torch.tensor(0, |
| | dtype=obs_steps.dtype, |
| | device=obs_steps.device)) |
| | action_mask = (steps.T < action_steps).T.reshape(B,T,1).expand(B,T,D) |
| | action_mask = action_mask & is_action_dim |
| |
|
| | mask = obs_mask |
| | if self.action_visible: |
| | mask = mask | action_mask |
| | |
| | return mask |
| |
|
| |
|
| | class KeypointMaskGenerator(ModuleAttrMixin): |
| | def __init__(self, |
| | |
| | action_dim, keypoint_dim, |
| | |
| | max_n_obs_steps=2, fix_obs_steps=True, |
| | |
| | keypoint_visible_rate=0.7, time_independent=False, |
| | |
| | action_visible=False, |
| | context_dim=0, |
| | n_context_steps=1 |
| | ): |
| | super().__init__() |
| | self.action_dim = action_dim |
| | self.keypoint_dim = keypoint_dim |
| | self.context_dim = context_dim |
| | self.max_n_obs_steps = max_n_obs_steps |
| | self.fix_obs_steps = fix_obs_steps |
| | self.keypoint_visible_rate = keypoint_visible_rate |
| | self.time_independent = time_independent |
| | self.action_visible = action_visible |
| | self.n_context_steps = n_context_steps |
| | |
| | @torch.no_grad() |
| | def forward(self, shape, seed=None): |
| | device = self.device |
| | B, T, D = shape |
| | all_keypoint_dims = D - self.action_dim - self.context_dim |
| | n_keypoints = all_keypoint_dims // self.keypoint_dim |
| | |
| | |
| | rng = torch.Generator(device=device) |
| | if seed is not None: |
| | rng = rng.manual_seed(seed) |
| | |
| | |
| | dim_mask = torch.zeros(size=shape, |
| | dtype=torch.bool, device=device) |
| | is_action_dim = dim_mask.clone() |
| | is_action_dim[...,:self.action_dim] = True |
| | is_context_dim = dim_mask.clone() |
| | if self.context_dim > 0: |
| | is_context_dim[...,-self.context_dim:] = True |
| | is_obs_dim = ~(is_action_dim | is_context_dim) |
| | |
| |
|
| | |
| | if self.fix_obs_steps: |
| | obs_steps = torch.full((B,), |
| | fill_value=self.max_n_obs_steps, device=device) |
| | else: |
| | obs_steps = torch.randint( |
| | low=1, high=self.max_n_obs_steps+1, |
| | size=(B,), generator=rng, device=device) |
| | |
| | steps = torch.arange(0, T, device=device).reshape(1,T).expand(B,T) |
| | obs_mask = (steps.T < obs_steps).T.reshape(B,T,1).expand(B,T,D) |
| | obs_mask = obs_mask & is_obs_dim |
| |
|
| | |
| | if self.action_visible: |
| | action_steps = torch.maximum( |
| | obs_steps - 1, |
| | torch.tensor(0, |
| | dtype=obs_steps.dtype, |
| | device=obs_steps.device)) |
| | action_mask = (steps.T < action_steps).T.reshape(B,T,1).expand(B,T,D) |
| | action_mask = action_mask & is_action_dim |
| |
|
| | |
| | if self.time_independent: |
| | visible_kps = torch.rand(size=(B, T, n_keypoints), |
| | generator=rng, device=device) < self.keypoint_visible_rate |
| | visible_dims = torch.repeat_interleave(visible_kps, repeats=self.keypoint_dim, dim=-1) |
| | visible_dims_mask = torch.cat([ |
| | torch.ones((B, T, self.action_dim), |
| | dtype=torch.bool, device=device), |
| | visible_dims, |
| | torch.ones((B, T, self.context_dim), |
| | dtype=torch.bool, device=device), |
| | ], axis=-1) |
| | keypoint_mask = visible_dims_mask |
| | else: |
| | visible_kps = torch.rand(size=(B,n_keypoints), |
| | generator=rng, device=device) < self.keypoint_visible_rate |
| | visible_dims = torch.repeat_interleave(visible_kps, repeats=self.keypoint_dim, dim=-1) |
| | visible_dims_mask = torch.cat([ |
| | torch.ones((B, self.action_dim), |
| | dtype=torch.bool, device=device), |
| | visible_dims, |
| | torch.ones((B, self.context_dim), |
| | dtype=torch.bool, device=device), |
| | ], axis=-1) |
| | keypoint_mask = visible_dims_mask.reshape(B,1,D).expand(B,T,D) |
| | keypoint_mask = keypoint_mask & is_obs_dim |
| |
|
| | |
| | context_mask = is_context_dim.clone() |
| | context_mask[:,self.n_context_steps:,:] = False |
| |
|
| | mask = obs_mask & keypoint_mask |
| | if self.action_visible: |
| | mask = mask | action_mask |
| | if self.context_dim > 0: |
| | mask = mask | context_mask |
| |
|
| | return mask |
| |
|
| |
|
| | def test(): |
| | |
| | |
| | |
| | self = LowdimMaskGenerator(2,20, max_n_obs_steps=3, action_visible=True) |
| |
|