File size: 1,362 Bytes
912c7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Optional

import torch
import torch.nn as nn
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder


class ResnetDP(nn.Module):
    """Observation encoder built on ``MultiImageObsEncoder`` with a ResNet-18 RGB model.

    The wrapper merges the first two axes of its inputs (batch and time, per
    the original comments), splits the images into one ``img_{i}`` entry per
    index along the following axis, runs the encoder once over all
    ``B * T`` samples, and returns the features reshaped to one row per
    batch element.
    """

    def __init__(self, shape_meta: dict):
        """Create the ResNet-18 RGB model and the multi-image encoder.

        Args:
            shape_meta: Forwarded unchanged to ``MultiImageObsEncoder``.
                NOTE(review): presumably it must declare the ``img_{i}`` and
                ``robot_state`` keys that ``forward`` produces -- confirm
                against the encoder.
        """
        super().__init__()
        rgb_model = get_resnet(name="resnet18")
        self.backbone = MultiImageObsEncoder(
            shape_meta=shape_meta,
            rgb_model=rgb_model,
            crop_shape=(76, 76),
            random_crop=True,
            use_group_norm=True,
            share_rgb_model=False,
            imagenet_norm=True,
        )

    def forward(
        self,
        images: torch.Tensor,
        robot_state_obs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of observation sequences into flat feature vectors.

        Args:
            images: Rank-6 tensor. The first two axes are merged and the last
                axis is moved in front of the two axes preceding it.
                Assumes a ``(B, T, N, H, W, C)`` channels-last layout with
                ``N`` images per step -- TODO confirm against the caller.
            robot_state_obs: Optional tensor whose first two axes are merged
                the same way; it is cast to float32 and passed to the encoder
                as ``"robot_state"``. When ``None`` the entry is omitted
                instead of raising on ``None.float()``.
                NOTE(review): the encoder may still require this key if
                ``shape_meta`` declares it -- verify.

        Returns:
            The encoder output reshaped to ``(B, -1)``, i.e. the features of
            all time steps of a batch element concatenated in one row.
        """
        batch_size = images.shape[0]

        # Merge batch and time, then move the trailing axis ahead of the two
        # spatial axes: (B*T, N, H, W, C) -> (B*T, N, C, H, W).
        images = images.reshape(-1, *images.shape[2:]).permute(0, 1, 4, 2, 3)

        # One encoder input per image index along axis 1.
        obs_dict = {f"img_{i}": images[:, i] for i in range(images.shape[1])}

        # The signature defaults this argument to None, so guard before use.
        if robot_state_obs is not None:
            obs_dict["robot_state"] = robot_state_obs.float().reshape(
                -1, *robot_state_obs.shape[2:]
            )

        # Encode every (batch, time) sample in a single pass.
        features = self.backbone(obs_dict)

        # Back to one row per batch element; per-step features end up
        # concatenated along the last axis.
        return features.reshape(batch_size, -1)