# WorldMem_Repro/algorithms/worldmem/models/pose_prediction.py
# BonanDing's picture
# update lfs
# 8652b14
import torch
import torch.nn as nn
import torch.nn.functional as F
class PosePredictionNet(nn.Module):
    """Regress a pose vector from an image-like tensor, an action and a pose.

    The network has three parts:

    * ``cnn`` + ``fc_img``: a stack of stride-2 3x3 convolutions followed by
      global average pooling and a linear projection to ``img_feat_dim``.
    * ``mlp_motion``: a two-layer MLP over the concatenation ``[pose, action]``.
    * ``fc_out``: a two-layer MLP over the concatenated image and motion
      features, producing a vector of size ``pose_dim``.

    NOTE(review): the default ``img_channels=16`` suggests the "image" is a
    latent feature map rather than RGB, and the output is presumably the pose
    at the next step (the result is called ``pose_next_pred``) -- confirm
    against the caller.

    Args:
        img_channels: Number of channels of the image input.
        img_feat_dim: Size of the image feature vector after ``fc_img``.
        pose_dim: Size of the pose vector (both input and output).
        action_dim: Size of the action vector.
        hidden_dim: Width of the hidden layers in ``mlp_motion`` and ``fc_out``.
        cnn_channels: Output channels of each stride-2 convolution, in order.
            The default ``(32, 64, 128)`` reproduces the original fixed
            architecture, including its ``state_dict`` key layout.

    Raises:
        ValueError: If ``cnn_channels`` is empty.
    """

    def __init__(self, img_channels: int = 16, img_feat_dim: int = 256,
                 pose_dim: int = 5, action_dim: int = 25,
                 hidden_dim: int = 128, cnn_channels=(32, 64, 128)):
        super().__init__()

        widths = [int(width) for width in cnn_channels]
        if not widths:
            raise ValueError("cnn_channels must contain at least one channel width")

        # Conv/ReLU pairs are appended in order so that, for the default
        # widths, the Sequential indices (0, 2, 4 for the convolutions) match
        # the original hand-written stack and existing checkpoints still load.
        conv_layers = []
        in_ch = img_channels
        for out_ch in widths:
            conv_layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
            )
            conv_layers.append(nn.ReLU())
            in_ch = out_ch
        # Global average pooling makes the feature size independent of H and W.
        conv_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnn = nn.Sequential(*conv_layers)

        # After pooling, each sample is a vector with widths[-1] entries.
        self.fc_img = nn.Linear(in_ch, img_feat_dim)

        self.mlp_motion = nn.Sequential(
            nn.Linear(pose_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.fc_out = nn.Sequential(
            nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(self, img: torch.Tensor, action: torch.Tensor,
                pose: torch.Tensor) -> torch.Tensor:
        """Run the network on one batch.

        Args:
            img: Tensor of shape ``(B, img_channels, H, W)``.
            action: Tensor of shape ``(B, action_dim)``.
            pose: Tensor of shape ``(B, pose_dim)``.

        Returns:
            Tensor of shape ``(B, pose_dim)``.
        """
        # (B, C, 1, 1) -> (B, C). flatten(1) is equivalent to
        # view(B, -1) here and also avoids the ambiguous -1 that view()
        # rejects when the batch is empty.
        img_feat = self.cnn(img).flatten(1)
        img_feat = self.fc_img(img_feat)

        # Pose comes first, then action, matching the input layout the first
        # mlp_motion layer was built for.
        motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))

        fused_feat = torch.cat([img_feat, motion_feat], dim=1)
        pose_next_pred = self.fc_out(fused_feat)
        return pose_next_pred