from typing import Optional

import torch
import torch.nn as nn

from diffusion_policy_3d.model.vision.pointnet_extractor import (
    PointNetEncoderXYZRGB,
    PointNetEncoderXYZ,
)


class MLP3DP(nn.Module):
    """Point-cloud observation encoder wrapping a PointNet-style backbone.

    The backbone class is selected from the number of per-point input
    channels: 3 -> ``PointNetEncoderXYZ``, 6 -> ``PointNetEncoderXYZRGB``.
    Both are built with layer norm and a final layer-norm.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the point-cloud backbone.

        Args:
            in_channels: Channels per point; must be 3 or 6.
            out_channels: Feature size forwarded to the backbone as its
                ``out_channels`` argument.

        Raises:
            ValueError: If ``in_channels`` is neither 3 nor 6.
        """
        super().__init__()
        # Only the encoder class differs between the two supported layouts;
        # choose it first, then construct once with the shared settings.
        if in_channels == 3:
            encoder_cls = PointNetEncoderXYZ
        elif in_channels == 6:
            encoder_cls = PointNetEncoderXYZRGB
        else:
            raise ValueError("Invalid number of input channels for MLP3DP")
        self.backbone = encoder_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            use_layernorm=True,
            final_norm="layernorm",
            normal_channel=False,
        )

    def forward(
        self,
        pcd: torch.Tensor,
        robot_state_obs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode point clouds (and optionally robot state) per time step.

        Args:
            pcd: Point clouds; the first two dims are merged before encoding.
                Presumably (B, T, N, C) — TODO confirm against callers.
            robot_state_obs: Optional state tensor whose first two dims are
                merged the same way (presumably (B, T, D)). When provided it
                is concatenated onto the point-cloud features along dim 1;
                when ``None`` only the point-cloud features are used.

        Returns:
            A tensor of shape ``(B, -1)`` in which each batch element holds
            its per-step feature vectors concatenated together.
        """
        batch_size = pcd.shape[0]
        # Merge batch and time dims so the backbone encodes every cloud at once.
        flat_pcd = pcd.float().reshape(-1, *pcd.shape[2:])
        features = self.backbone(flat_pcd)
        if robot_state_obs is not None:
            flat_state = robot_state_obs.float().reshape(
                -1, *robot_state_obs.shape[2:]
            )
            features = torch.cat([features, flat_state], dim=1)
        # Restore the batch dim; the per-step features end up concatenated.
        return features.reshape(batch_size, -1)