import torch import torch.nn as nn from diffusion_policy.model.vision.model_getter import get_resnet from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder class ResnetDP(nn.Module): def __init__(self, shape_meta: dict): super().__init__() rgb_model = get_resnet(name="resnet18") self.backbone = MultiImageObsEncoder( shape_meta=shape_meta, rgb_model=rgb_model, crop_shape=(76, 76), random_crop=True, use_group_norm=True, share_rgb_model=False, imagenet_norm=True, ) return def forward(self, images: torch.Tensor, robot_state_obs: torch.Tensor = None) -> torch.Tensor: B = images.shape[0] # Flatten the batch and time dimensions images = images.reshape(-1, *images.shape[2:]).permute(0, 1, 4, 2, 3) robot_state_obs = robot_state_obs.float().reshape(-1, *robot_state_obs.shape[2:]) # Encode all observations (across time steps and batch size) obs_dict = {f"img_{i}": images[:, i] for i in range(images.shape[1])} obs_dict["robot_state"] = robot_state_obs nx = self.backbone(obs_dict) # Reshape back to the batch dimension. Now the features of each time step are concatenated nx = nx.reshape(B, -1) return nx