lsnu's picture
Add files using upload-large-folder tool
912c7e2 verified
import torch
import torch.nn as nn
from diffusion_policy_3d.model.vision.pointnet_extractor import (
PointNetEncoderXYZRGB,
PointNetEncoderXYZ,
)
class MLP3DP(nn.Module):
    """Point-cloud observation encoder built on a PointNet backbone.

    The backbone is chosen from the number of per-point input channels:
    ``PointNetEncoderXYZ`` for 3 channels and ``PointNetEncoderXYZRGB`` for 6.
    The encoded point-cloud features are optionally concatenated with a
    robot-state observation and flattened to one vector per batch element.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """Build the backbone encoder.

        Args:
            in_channels: Per-point channel count; must be 3 or 6.
            out_channels: Passed to the backbone as ``out_channels``.

        Raises:
            ValueError: If ``in_channels`` is neither 3 nor 6.
        """
        super().__init__()

        # Both encoders take the same keyword arguments, so pick the class
        # first and construct it once.
        if in_channels == 3:
            encoder_cls = PointNetEncoderXYZ
        elif in_channels == 6:
            encoder_cls = PointNetEncoderXYZRGB
        else:
            raise ValueError("Invalid number of input channels for MLP3DP")

        self.backbone = encoder_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            use_layernorm=True,
            final_norm="layernorm",
            normal_channel=False,
        )

    def forward(self, pcd: torch.Tensor, robot_state_obs: torch.Tensor = None) -> torch.Tensor:
        """Encode point clouds and, if given, append the robot state.

        Args:
            pcd: Point-cloud tensor whose first two dimensions are merged
                before encoding (batch and time, per the original comments);
                the remaining dimensions are passed to the backbone unchanged.
            robot_state_obs: Optional tensor with the same two leading
                dimensions as ``pcd``. When ``None``, only the encoded
                point-cloud features are returned.

        Returns:
            A tensor of shape ``(B, -1)`` where ``B = pcd.shape[0]``; the
            per-step features of each batch element are laid out
            contiguously along the last dimension.
        """
        batch_size = pcd.shape[0]

        # Merge the two leading dimensions so the backbone sees one point
        # cloud per row.
        flat_pcd = pcd.float().reshape(-1, *pcd.shape[2:])

        # Encode all point clouds (across time steps and batch size).
        features = self.backbone(flat_pcd)

        # The signature advertises ``None`` as the default, so the state must
        # be optional here; previously this path raised AttributeError.
        if robot_state_obs is not None:
            flat_state = robot_state_obs.float().reshape(-1, *robot_state_obs.shape[2:])
            features = torch.cat([features, flat_state], dim=1)

        # Reshape back to the batch dimension. The features of each time step
        # end up concatenated along the last axis.
        return features.reshape(batch_size, -1)