import torch
import torch.nn as nn
import torch.nn.functional as F


class PosePredictionNet(nn.Module):
    """Regress a ``pose_dim``-sized vector from an image, an action and a pose.

    The network has three parts:

    * ``cnn`` + ``fc_img``: three stride-2 3x3 convolutions followed by a
      global average pool and a linear projection, producing an
      ``img_feat_dim`` feature vector per sample.
    * ``mlp_motion``: a two-layer MLP applied to the concatenation
      ``[pose, action]``, producing a ``hidden_dim`` feature vector.
    * ``fc_out``: a two-layer MLP applied to the concatenation
      ``[image feature, motion feature]``, producing ``pose_dim`` outputs.

    NOTE(review): judging by the names, the output is presumably the
    predicted *next* pose given the current pose and an action -- confirm
    against the training code.

    Args:
        img_channels: Number of channels of the input image tensor.
        img_feat_dim: Size of the image feature vector produced by ``fc_img``.
        pose_dim: Size of the input pose vector and of the network output.
        action_dim: Size of the input action vector.
        hidden_dim: Width of the hidden layers in ``mlp_motion`` and
            ``fc_out``.
    """

    def __init__(
        self,
        img_channels: int = 16,
        img_feat_dim: int = 256,
        pose_dim: int = 5,
        action_dim: int = 25,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()

        # Submodules are created in the same order and under the same names
        # as before so that state_dict keys and seeded initialisation are
        # unchanged.

        # Image encoder. Each stride-2 convolution roughly halves the
        # spatial size; the adaptive pool then collapses whatever is left to
        # 1x1, so no particular input resolution is baked into the model.
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Projects the 128 pooled channels to the image feature size.
        self.fc_img = nn.Linear(128, img_feat_dim)

        # Encoder for the concatenated [pose, action] vector.
        self.mlp_motion = nn.Sequential(
            nn.Linear(pose_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Regression head over the fused [image, motion] features.
        self.fc_out = nn.Sequential(
            nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(
        self,
        img: torch.Tensor,
        action: torch.Tensor,
        pose: torch.Tensor,
    ) -> torch.Tensor:
        """Run the network on one batch.

        Args:
            img: Image tensor of shape ``(B, img_channels, H, W)``.
            action: Action tensor of shape ``(B, action_dim)``.
            pose: Pose tensor of shape ``(B, pose_dim)``.

        Returns:
            Tensor of shape ``(B, pose_dim)``.
        """
        # (B, 128, 1, 1) -> (B, 128). ``flatten`` is used instead of ``view``
        # so the result does not depend on the memory layout of the pooled
        # tensor.
        img_feat = self.cnn(img).flatten(1)
        img_feat = self.fc_img(img_feat)

        # Concatenation order (pose first, then action) is part of the
        # learned weights' layout and must not change.
        motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))

        # Likewise: image features first, motion features second.
        fused_feat = torch.cat([img_feat, motion_feat], dim=1)
        pose_next_pred = self.fc_out(fused_feat)

        return pose_next_pred