File size: 7,676 Bytes
c1f1d32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
from typing import Dict, Tuple, Union
import copy
import torch
import torch.nn as nn
import torchvision
from equi_diffpo.model.vision.crop_randomizer import CropRandomizer
from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
from equi_diffpo.common.pytorch_util import dict_apply, replace_submodules
class MultiImageObsEncoder(ModuleAttrMixin):
    """Encode a dict of rgb and low-dim observations into one flat feature tensor.

    Every key in ``shape_meta['obs']`` is classified by its ``type`` entry
    (``'rgb'`` or ``'low_dim'``; a missing entry means ``'low_dim'``).
    Each rgb observation goes through a per-key transform pipeline
    (resize -> crop -> normalize) and then a vision model; low-dim
    observations are passed through unchanged.  All resulting features are
    concatenated along the last dimension, rgb keys first and low-dim keys
    second, each group in sorted key order.
    """

    def __init__(self,
            shape_meta: dict,
            rgb_model: Union[nn.Module, Dict[str,nn.Module]],
            resize_shape: Union[Tuple[int,int], Dict[str,tuple], None]=None,
            crop_shape: Union[Tuple[int,int], Dict[str,tuple], None]=None,
            random_crop: bool=True,
            # replace BatchNorm with GroupNorm
            use_group_norm: bool=False,
            # use single rgb model for all rgb inputs
            share_rgb_model: bool=False,
            # renormalize rgb input with imagenet normalization
            # assuming input in [0,1]
            imagenet_norm: bool=False
        ):
        """
        Assumes rgb input: B,C,H,W
        Assumes low_dim input: B,D

        Args:
            shape_meta: dict whose ``'obs'`` entry maps each observation key
                to a dict holding ``'shape'`` (per-sample shape, without the
                batch dimension) and optionally ``'type'``.
            rgb_model: a single module (deep-copied for each rgb key unless
                ``share_rgb_model`` is set) or a dict with one module per
                rgb key.
            resize_shape: ``(h, w)`` applied to every rgb key, or a dict of
                ``(h, w)`` per key. ``None`` disables resizing.
            crop_shape: ``(h, w)`` applied to every rgb key, or a dict of
                ``(h, w)`` per key. ``None`` disables cropping.
            random_crop: when cropping, use ``CropRandomizer`` if True,
                otherwise a center crop.
            use_group_norm: replace every ``nn.BatchNorm2d`` in the per-key
                rgb models with ``nn.GroupNorm``.
            share_rgb_model: run all rgb keys through the one ``rgb_model``
                instance, stored under the key ``'rgb'``.
            imagenet_norm: normalize rgb input with the ImageNet mean/std.

        Raises:
            RuntimeError: if an observation has a type other than ``'rgb'``
                or ``'low_dim'``.
        """
        super().__init__()

        rgb_keys = list()
        low_dim_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        # handle sharing vision backbone
        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            # NOTE(review): use_group_norm is only applied to the per-key
            # models built below; the shared model is registered unchanged.
            # Confirm whether that is intended before altering it, since it
            # affects the parameter set of existing checkpoints.
            key_model_map['rgb'] = rgb_model

        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)

                # configure model for this key
                this_model = None
                if not share_rgb_model:
                    if isinstance(rgb_model, dict):
                        # have provided model for each key
                        this_model = rgb_model[key]
                    else:
                        assert isinstance(rgb_model, nn.Module)
                        # have a copy of the rgb model
                        this_model = copy.deepcopy(rgb_model)

                if this_model is not None:
                    if use_group_norm:
                        this_model = replace_submodules(
                            root_module=this_model,
                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                            func=lambda x: nn.GroupNorm(
                                num_groups=x.num_features//16,
                                num_channels=x.num_features)
                        )
                    key_model_map[key] = this_model

                # configure resize
                input_shape = shape
                this_resizer = nn.Identity()
                if resize_shape is not None:
                    if isinstance(resize_shape, dict):
                        h, w = resize_shape[key]
                    else:
                        h, w = resize_shape
                    this_resizer = torchvision.transforms.Resize(
                        size=(h,w)
                    )
                    # the crop stage sees the resized image
                    input_shape = (shape[0],h,w)

                # configure randomizer (the crop stage of the pipeline)
                this_randomizer = nn.Identity()
                if crop_shape is not None:
                    if isinstance(crop_shape, dict):
                        h, w = crop_shape[key]
                    else:
                        h, w = crop_shape
                    if random_crop:
                        this_randomizer = CropRandomizer(
                            input_shape=input_shape,
                            crop_height=h,
                            crop_width=w,
                            num_crops=1,
                            pos_enc=False
                        )
                    else:
                        # Must occupy the crop slot: the normalizer variable
                        # is reassigned unconditionally below, so storing the
                        # center crop there would silently discard it.
                        this_randomizer = torchvision.transforms.CenterCrop(
                            size=(h,w)
                        )

                # configure normalizer
                this_normalizer = nn.Identity()
                if imagenet_norm:
                    this_normalizer = torchvision.transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

                this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
                key_transform_map[key] = this_transform
            elif obs_type == 'low_dim':
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # sorted so the feature layout does not depend on dict ordering
        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)

        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map

    def _check_obs(self, key, data, batch_size):
        """Validate one observation tensor and return the batch size in effect.

        Asserts that ``data`` has the same leading dimension as earlier
        observations (when ``batch_size`` is not None) and that its remaining
        dimensions equal the shape recorded for ``key``.
        """
        if batch_size is None:
            batch_size = data.shape[0]
        else:
            assert batch_size == data.shape[0]
        assert data.shape[1:] == self.key_shape_map[key]
        return batch_size

    def forward(self, obs_dict):
        """Encode ``obs_dict`` into a single feature tensor.

        Args:
            obs_dict: mapping from observation key to tensor; it must contain
                every rgb and low-dim key found in ``shape_meta``.

        Returns:
            The rgb features followed by the raw low-dim observations,
            concatenated along the last dimension.

        Raises:
            AssertionError: if batch sizes differ between observations or a
                tensor's per-sample shape does not match ``shape_meta``.
        """
        batch_size = None
        features = list()

        # process rgb input
        if self.share_rgb_model:
            # pass all rgb obs to rgb model
            imgs = list()
            for key in self.rgb_keys:
                img = obs_dict[key]
                batch_size = self._check_obs(key, img, batch_size)
                img = self.key_transform_map[key](img)
                imgs.append(img)
            # torch.cat rejects an empty list, so skip the shared backbone
            # entirely when there are no rgb observations.
            if imgs:
                # (N*B,C,H,W)
                imgs = torch.cat(imgs, dim=0)
                # (N*B,D)
                feature = self.key_model_map['rgb'](imgs)
                # (N,B,D)
                feature = feature.reshape(-1,batch_size,*feature.shape[1:])
                # (B,N,D)
                feature = torch.moveaxis(feature,0,1)
                # (B,N*D)
                feature = feature.reshape(batch_size,-1)
                features.append(feature)
        else:
            # run each rgb obs to independent models
            for key in self.rgb_keys:
                img = obs_dict[key]
                batch_size = self._check_obs(key, img, batch_size)
                img = self.key_transform_map[key](img)
                feature = self.key_model_map[key](img)
                features.append(feature)

        # process lowdim input
        for key in self.low_dim_keys:
            data = obs_dict[key]
            batch_size = self._check_obs(key, data, batch_size)
            features.append(data)

        # concatenate all features
        result = torch.cat(features, dim=-1)
        return result

    @torch.no_grad()
    def output_shape(self):
        """Return the per-sample shape of ``forward``'s output.

        Builds an all-zeros observation with batch size 1 for every key in
        ``shape_meta['obs']``, runs it through ``forward`` and drops the
        batch dimension from the resulting shape.
        """
        example_obs_dict = dict()
        obs_shape_meta = self.shape_meta['obs']
        batch_size = 1
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            # self.dtype / self.device are presumably supplied by
            # ModuleAttrMixin -- not visible in this file.
            this_obs = torch.zeros(
                (batch_size,) + shape,
                dtype=self.dtype,
                device=self.device)
            example_obs_dict[key] = this_obs
        example_output = self.forward(example_obs_dict)
        output_shape = example_output.shape[1:]
        return output_shape
|