File size: 1,591 Bytes
912c7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
from diffusion_policy_3d.model.vision.pointnet_extractor import (
    PointNetEncoderXYZRGB,
    PointNetEncoderXYZ,
)


class MLP3DP(nn.Module):
    """Point-cloud observation encoder built on a PointNet backbone.

    Selects ``PointNetEncoderXYZ`` for 3-channel point clouds and
    ``PointNetEncoderXYZRGB`` for 6-channel ones. Both are constructed with
    the same options (``use_layernorm=True``, ``final_norm="layernorm"``,
    ``normal_channel=False``).

    Args:
        in_channels: Number of per-point channels; must be 3 or 6.
        out_channels: Forwarded to the backbone as its ``out_channels``.

    Raises:
        ValueError: If ``in_channels`` is neither 3 nor 6.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        if in_channels == 3:
            encoder_cls = PointNetEncoderXYZ
        elif in_channels == 6:
            encoder_cls = PointNetEncoderXYZRGB
        else:
            raise ValueError("Invalid number of input channels for MLP3DP")
        # Both variants share one configuration; only the encoder class differs.
        self.backbone = encoder_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            use_layernorm=True,
            final_norm="layernorm",
            normal_channel=False,
        )

    def forward(
        self,
        pcd: torch.Tensor,
        robot_state_obs: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Encode a batch of point-cloud sequences plus optional robot state.

        Args:
            pcd: Point clouds whose first two dimensions (batch and time) are
                flattened together before encoding; presumably shaped
                (B, T, N, C) — TODO confirm against callers.
            robot_state_obs: Optional per-step state sharing the same two
                leading dimensions as ``pcd``. When given, it is appended
                (along dim 1) after each step's point-cloud features. When
                ``None``, only the point-cloud features are returned.

        Returns:
            A float32 tensor of shape ``(B, -1)`` with ``B = pcd.shape[0]``,
            in which the feature vectors of each time step are concatenated.
        """
        batch_size = pcd.shape[0]
        # Flatten the batch and time dimensions so the backbone encodes every
        # point cloud (across time steps and batch) in a single call.
        pcd = pcd.float().reshape(-1, *pcd.shape[2:])
        # assumes the backbone returns (B*T, features) — TODO confirm
        features = self.backbone(pcd)
        if robot_state_obs is not None:
            # The default is None, so the state is optional; only touch it
            # when it was actually provided.
            robot_state_obs = robot_state_obs.float().reshape(
                -1, *robot_state_obs.shape[2:]
            )
            features = torch.cat([features, robot_state_obs], dim=1)
        # Restore the batch dimension; per-step features end up concatenated.
        return features.reshape(batch_size, -1)