File size: 1,591 Bytes
912c7e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import torch
import torch.nn as nn
from diffusion_policy_3d.model.vision.pointnet_extractor import (
PointNetEncoderXYZRGB,
PointNetEncoderXYZ,
)
class MLP3DP(nn.Module):
    """Observation encoder wrapping a PointNet backbone.

    The backbone class is chosen from ``in_channels``: 3 selects
    ``PointNetEncoderXYZ`` and 6 selects ``PointNetEncoderXYZRGB``. Both are
    constructed with ``use_layernorm=True``, ``final_norm="layernorm"`` and
    ``normal_channel=False``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """Build the point-cloud backbone.

        Args:
            in_channels: Number of per-point input channels; must be 3 or 6.
            out_channels: Forwarded to the backbone as its ``out_channels``.

        Raises:
            ValueError: If ``in_channels`` is neither 3 nor 6.
        """
        super().__init__()
        if in_channels == 3:
            encoder_cls = PointNetEncoderXYZ
        elif in_channels == 6:
            encoder_cls = PointNetEncoderXYZRGB
        else:
            raise ValueError("Invalid number of input channels for MLP3DP")
        # Both encoders share the same constructor arguments; only the class differs.
        self.backbone = encoder_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            use_layernorm=True,
            final_norm="layernorm",
            normal_channel=False,
        )

    def forward(
        self, pcd: torch.Tensor, robot_state_obs: torch.Tensor = None
    ) -> torch.Tensor:
        """Encode point clouds (and optionally robot state) for a batch.

        The first two dimensions of each input are merged so the backbone
        processes every (batch, step) point cloud in a single call. The result
        is reshaped back to ``batch_size`` rows, so each row holds the features
        of all steps concatenated.

        Args:
            pcd: Point clouds whose two leading dims are flattened together.
                Presumably (B, T, N, C); TODO confirm against callers.
            robot_state_obs: Optional robot state with the same two leading
                dims as ``pcd`` (presumably (B, T, D)). When ``None``, only the
                point-cloud features are returned.

        Returns:
            Tensor of shape ``(pcd.shape[0], -1)``.
        """
        batch_size = pcd.shape[0]
        # Merge batch and time dims so the backbone sees one cloud per row.
        flat_pcd = pcd.float().reshape(-1, *pcd.shape[2:])
        features = self.backbone(flat_pcd)
        # The default of None was previously unusable (``None.float()``); only
        # append robot state when it is actually supplied.
        if robot_state_obs is not None:
            flat_state = robot_state_obs.float().reshape(
                -1, *robot_state_obs.shape[2:]
            )
            # Concatenate along the feature axis, per (batch, step) row.
            features = torch.cat([features, flat_state], dim=1)
        # Restore the batch dim; each row now concatenates every step's features.
        return features.reshape(batch_size, -1)
|