import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import diffusion_policy.model.common.tensor_util as tu


class CropRandomizer(nn.Module):
    """
    Randomly sample crops at input, and then average across crop features at output.
    """
    def __init__(
        self,
        input_shape,
        crop_height,
        crop_width,
        num_crops=1,
        pos_enc=False,
    ):
        """
        Args:
            input_shape (tuple, list): shape of input (not including batch dimension)
            crop_height (int): crop height
            crop_width (int): crop width
            num_crops (int): number of random crops to take
            pos_enc (bool): if True, add 2 channels to the output to encode the spatial
                location of the cropped pixels in the source image
        """
        super().__init__()

        # inputs must be (C, H, W), and crops must fit strictly inside the image
        assert len(input_shape) == 3
        assert crop_height < input_shape[1]
        assert crop_width < input_shape[2]

        self.input_shape = input_shape
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.num_crops = num_crops
        self.pos_enc = pos_enc

    def output_shape_in(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_in operation, where raw inputs (usually observation modalities)
        are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # outputs are shape [C, CH, CW], or [C + 2, CH, CW] if position encoding is used.
        # The number of crops is folded into the batch dimension by @forward_in, so it
        # does not show up in the per-sample shape.
        out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
        return [out_c, self.crop_height, self.crop_width]

    def output_shape_out(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_out operation, where processed inputs (usually encoded observation
        modalities) are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # the @forward_out operation splits [B * N, ...] -> [B, N, ...] and then pools
        # the N dimension away, so the per-sample shape is unchanged
        return list(input_shape)

    def forward_in(self, inputs):
        """
        Samples N random crops for each input in the batch, and then reshapes
        inputs to [B * N, ...].
        """
        assert len(inputs.shape) >= 3  # must have at least (C, H, W) dimensions
        if self.training:
            # generate random crops at train time
            out, _ = sample_random_image_crops(
                images=inputs,
                crop_height=self.crop_height,
                crop_width=self.crop_width,
                num_crops=self.num_crops,
                pos_enc=self.pos_enc,
            )
            # [B, N, ...] -> [B * N, ...]
            return tu.join_dimensions(out, 0, 1)
        else:
            # take a deterministic center crop at eval time
            out = ttf.center_crop(img=inputs, output_size=(
                self.crop_height, self.crop_width))
            if self.num_crops > 1:
                # repeat the center crop N times so the batch dimension still
                # matches what @forward_out expects
                B, C, H, W = out.shape
                out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
            return out

    def forward_out(self, inputs):
        """
        Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then averages across N
        to produce shape [B, ...], so that the network output is consistent with what would
        have happened if there were no randomization.
        """
        if self.num_crops <= 1:
            return inputs
        else:
            # [B * N, ...] -> [B, N, ...], then average over the N crops
            batch_size = (inputs.shape[0] // self.num_crops)
            out = tu.reshape_dimensions(inputs, begin_axis=0, end_axis=0,
                target_dims=(batch_size, self.num_crops))
            return out.mean(dim=1)

    def forward(self, inputs):
        return self.forward_in(inputs)

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
            self.input_shape, self.crop_height, self.crop_width, self.num_crops)
        return msg
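
# A usage sketch for CropRandomizer (shapes and `some_encoder` are hypothetical,
# chosen only to illustrate the forward_in/forward_out contract):
#
#   randomizer = CropRandomizer(input_shape=(3, 84, 84), crop_height=76, crop_width=76, num_crops=4)
#   imgs = torch.randn(8, 3, 84, 84)        # [B, C, H, W]
#   crops = randomizer.forward_in(imgs)     # [B * N, C, 76, 76] in training mode
#   feats = some_encoder(crops)             # any per-crop encoder, e.g. [B * N, D]
#   pooled = randomizer.forward_out(feats)  # [B, D], averaged over the N crops

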
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
    """
    Crops images at the locations specified by @crop_indices. Crops will be
    taken across all channels.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
            N is the number of crops to take per image and each entry corresponds
            to the pixel height and width of where to take the crop. Note that
            the indices can also be of shape [..., 2] if only 1 crop should
            be taken per image. Leading dimensions must be consistent with
            @images argument. Each index specifies the top left of the crop.
            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
            H and W are the height and width of @images and CH and CW are
            @crop_height and @crop_width.

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

    Returns:
        crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
    """
    # make sure length of input shapes is consistent
    assert crop_indices.shape[-1] == 2
    ndim_im_shape = len(images.shape)
    ndim_indices_shape = len(crop_indices.shape)
    assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)

    # maybe pad so that @crop_indices is shape [..., N, 2]
    is_padded = False
    if ndim_im_shape == ndim_indices_shape + 2:
        crop_indices = crop_indices.unsqueeze(-2)
        is_padded = True

    # make sure leading dimensions between images and indices are consistent
    assert images.shape[:-3] == crop_indices.shape[:-2]

    device = images.device
    image_c, image_h, image_w = images.shape[-3:]
    num_crops = crop_indices.shape[-2]

    # make sure @crop_indices are in valid range
    assert (crop_indices[..., 0] >= 0).all().item()
    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
    assert (crop_indices[..., 1] >= 0).all().item()
    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()

    # convert each crop index (ch, cw) into a grid of pixel indices that covers the entire crop window

    # 2D index array where each column is [0, 1, ..., CH - 1], shape [CH, CW]
    crop_ind_grid_h = torch.arange(crop_height).to(device)
    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
    # 2D index array where each row is [0, 1, ..., CW - 1], shape [CH, CW]
    crop_ind_grid_w = torch.arange(crop_width).to(device)
    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
    # combine into shape [CH, CW, 2]
    crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)

    # Add the top-left offset of each sampled crop to the grid above to get 2D pixel
    # indices for each crop - shape [..., N, CH, CW, 2]
    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)

    # flatten the 2D indices into 1D indices for use with @torch.gather, and repeat
    # them across the channel dimension
    all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1]  # shape [..., N, CH, CW]
    all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3)  # shape [..., N, C, CH, CW]
    all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2)  # shape [..., N, C, CH * CW]

    # repeat and flatten the source images -> [..., N, C, H * W], then gather the crop pixels
    images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
    images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
    crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
    # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
    reshape_axis = len(crops.shape) - 1
    crops = tu.reshape_dimensions(crops, begin_axis=reshape_axis, end_axis=reshape_axis,
        target_dims=(crop_height, crop_width))

    if is_padded:
        # undo padding -> [..., C, CH, CW]
        crops = crops.squeeze(-4)
    return crops
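
# Example sketch for crop_image_from_indices (hypothetical values): with the
# [..., 2] index form, one crop is taken per image and the crop dimension is
# squeezed away:
#
#   images = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
#   inds = torch.tensor([[0, 0], [1, 1]])  # one (h, w) top-left corner per image
#   crops = crop_image_from_indices(images, inds, crop_height=2, crop_width=2)
#   # crops.shape == torch.Size([2, 3, 2, 2])

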
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
    """
    For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
    @images.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

        num_crops (int): number of crops to sample

        pos_enc (bool): if True, also add 2 channels to the outputs that give a spatial
            encoding of the original source pixel locations. This means that the
            output crops will contain information about where in the source image
            they were sampled from.

    Returns:
        crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
            if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)

        crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
    """
    device = images.device

    # maybe add 2 channels of spatial encoding to the source image
    source_im = images
    if pos_enc:
        # build coordinate maps in [0, 1) for the (y, x) pixel locations; the explicit
        # indexing="ij" matches the legacy meshgrid default and silences the
        # deprecation warning on recent versions of torch
        h, w = source_im.shape[-2:]
        pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        pos_y = pos_y.float().to(device) / float(h)
        pos_x = pos_x.float().to(device) / float(w)
        position_enc = torch.stack((pos_y, pos_x))  # shape [2, H, W]

        # unsqueeze and expand to match leading dimensions -> shape [..., 2, H, W]
        leading_shape = source_im.shape[:-3]
        position_enc = position_enc[(None,) * len(leading_shape)]
        position_enc = position_enc.expand(*leading_shape, -1, -1, -1)

        # concat across the channel dimension with the input
        source_im = torch.cat((source_im, position_enc), dim=-3)

    # make sure sample boundaries keep crops fully within the images
    image_c, image_h, image_w = source_im.shape[-3:]
    max_sample_h = image_h - crop_height
    max_sample_w = image_w - crop_width

    # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
    # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
    # we will sample [B, N] indices, but this supports having more than one leading dimension,
    # or possibly no leading dimension.
    #
    # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to integer by casting.
    # (torch.randint does not support sampling with arbitrary leading dimensions the way we need here.)
    crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1)  # shape [..., N, 2]

    crops = crop_image_from_indices(
        images=source_im,
        crop_indices=crop_inds,
        crop_height=crop_height,
        crop_width=crop_width,
    )

    return crops, crop_inds
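
if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): checks the shapes
    # documented above, including the extra 2 channels added by pos_enc=True.
    imgs = torch.randn(4, 3, 32, 32)

    crops, inds = sample_random_image_crops(
        imgs, crop_height=24, crop_width=24, num_crops=5, pos_enc=True)
    assert crops.shape == (4, 5, 3 + 2, 24, 24)
    assert inds.shape == (4, 5, 2)

    randomizer = CropRandomizer(input_shape=(3, 32, 32), crop_height=24, crop_width=24, num_crops=5)
    out = randomizer.forward_in(imgs)  # a fresh nn.Module is in training mode
    assert out.shape == (4 * 5, 3, 24, 24)
    pooled = randomizer.forward_out(torch.randn(4 * 5, 16))
    assert pooled.shape == (4, 16)
    print(randomizer)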