from typing import Dict, Tuple, Union
import copy
import torch
import torch.nn as nn
import torchvision
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
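
# MultiImageObsEncoder turns a dict of observations into one flat feature
# vector: every rgb key gets its own transform pipeline (resize -> crop ->
# normalize) and a vision backbone (shared or per-key), low_dim keys pass
# through unchanged, and all per-key features are concatenated.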
class MultiImageObsEncoder(ModuleAttrMixin):
    def __init__(self,
            shape_meta: dict,
            rgb_model: Union[nn.Module, Dict[str, nn.Module]],
            resize_shape: Union[Tuple[int, int], Dict[str, tuple], None]=None,
            crop_shape: Union[Tuple[int, int], Dict[str, tuple], None]=None,
            random_crop: bool=True,
            # replace BatchNorm with GroupNorm
            use_group_norm: bool=False,
            # use a single shared rgb model for all rgb inputs
            share_rgb_model: bool=False,
            # renormalize rgb input with imagenet normalization
            # assumes input is in [0,1]
            imagenet_norm: bool=False
        ):
        """
        Assumes rgb input: B,C,H,W
        Assumes low_dim input: B,D
        """
        super().__init__()
        rgb_keys = list()
        low_dim_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        # handle sharing the vision backbone across all rgb keys
        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            key_model_map['rgb'] = rgb_model
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)
                # configure the model for this key
                this_model = None
                if not share_rgb_model:
                    if isinstance(rgb_model, dict):
                        # a separate model is provided for each key
                        this_model = rgb_model[key]
                    else:
                        assert isinstance(rgb_model, nn.Module)
                        # each key gets its own copy of the rgb model
                        this_model = copy.deepcopy(rgb_model)

                if this_model is not None:
                    if use_group_norm:
                        # replace all BatchNorm2d submodules with GroupNorm
                        this_model = replace_submodules(
                            root_module=this_model,
                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                            func=lambda x: nn.GroupNorm(
                                num_groups=x.num_features//16,
                                num_channels=x.num_features)
                        )
                    key_model_map[key] = this_model
                # configure resize
                input_shape = shape
                this_resizer = nn.Identity()
                if resize_shape is not None:
                    if isinstance(resize_shape, dict):
                        h, w = resize_shape[key]
                    else:
                        h, w = resize_shape
                    this_resizer = torchvision.transforms.Resize(
                        size=(h, w)
                    )
                    input_shape = (shape[0], h, w)
                # configure randomizer
                this_randomizer = nn.Identity()
                if crop_shape is not None:
                    if isinstance(crop_shape, dict):
                        h, w = crop_shape[key]
                    else:
                        h, w = crop_shape
                    if random_crop:
                        this_randomizer = CropRandomizer(
                            input_shape=input_shape,
                            crop_height=h,
                            crop_width=w,
                            num_crops=1,
                            pos_enc=False
                        )
                    else:
                        # deterministic center crop; assigning it to
                        # this_randomizer (not this_normalizer) keeps it from
                        # being overwritten by the normalizer setup below
                        this_randomizer = torchvision.transforms.CenterCrop(
                            size=(h, w)
                        )
                # configure normalizer
                this_normalizer = nn.Identity()
                if imagenet_norm:
                    this_normalizer = torchvision.transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
                key_transform_map[key] = this_transform
            elif obs_type == 'low_dim':
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")
        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)
        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map
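
    # forward: validate each observation against shape_meta, apply the per-key
    # transform and backbone, then concatenate all features to (B, D_total)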
    def forward(self, obs_dict):
        batch_size = None
        features = list()
        # process rgb input
        if self.share_rgb_model:
            # pass all rgb obs through the shared rgb model in one batch
            imgs = list()
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                imgs.append(img)
            # (N*B,C,H,W): stack the N camera streams along the batch dim
            imgs = torch.cat(imgs, dim=0)
            # (N*B,D)
            feature = self.key_model_map['rgb'](imgs)
            # (N,B,D)
            feature = feature.reshape(-1, batch_size, *feature.shape[1:])
            # (B,N,D)
            feature = torch.moveaxis(feature, 0, 1)
            # (B,N*D)
            feature = feature.reshape(batch_size, -1)
            features.append(feature)
        else:
            # run each rgb obs through its own model
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                feature = self.key_model_map[key](img)
                features.append(feature)

        # process low_dim input (passed through unchanged)
        for key in self.low_dim_keys:
            data = obs_dict[key]
            if batch_size is None:
                batch_size = data.shape[0]
            else:
                assert batch_size == data.shape[0]
            assert data.shape[1:] == self.key_shape_map[key]
            features.append(data)

        # concatenate all per-key features along the last dimension
        result = torch.cat(features, dim=-1)
        return result
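
    # infer the encoder's output dimensionality empirically by running a
    # zero-filled example observation through forward()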
    @torch.no_grad()
    def output_shape(self):
        example_obs_dict = dict()
        obs_shape_meta = self.shape_meta['obs']
        batch_size = 1
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            this_obs = torch.zeros(
                (batch_size,) + shape,
                dtype=self.dtype,
                device=self.device)
            example_obs_dict[key] = this_obs
        example_output = self.forward(example_obs_dict)
        output_shape = example_output.shape[1:]
        return output_shape
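

if __name__ == '__main__':
    # Minimal usage sketch, illustrative only. The observation keys, image
    # resolution, and the resnet18 backbone below are assumptions chosen for
    # demonstration, not values prescribed by this module.
    shape_meta = {
        'obs': {
            'camera_image': {'shape': (3, 96, 96), 'type': 'rgb'},
            'robot_eef_pos': {'shape': (3,), 'type': 'low_dim'}
        }
    }
    backbone = torchvision.models.resnet18(weights=None)
    backbone.fc = nn.Identity()  # drop the classifier head, keep 512-d features
    encoder = MultiImageObsEncoder(
        shape_meta=shape_meta,
        rgb_model=backbone,
        crop_shape=(84, 84),
        random_crop=True,
        use_group_norm=True,
        imagenet_norm=True)
    # 512 (rgb feature) + 3 (low_dim) = 515
    print(encoder.output_shape())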