# File size: 1,460 Bytes (8652b14)
import torch
import torch.nn as nn
import torch.nn.functional as F
class PosePredictionNet(nn.Module):
    """Predict a pose vector from an image, an action vector and a pose vector.

    The network has three parts:

    * ``cnn`` + ``fc_img``: three stride-2 3x3 convolutions (each roughly
      halves the spatial size), a global average pool to 1x1, and a linear
      projection to ``img_feat_dim`` features.
    * ``mlp_motion``: a two-layer ReLU MLP over the concatenation
      ``[pose, action]``.
    * ``fc_out``: a two-layer head over the concatenated image and motion
      features that outputs ``pose_dim`` values.

    Args:
        img_channels: Number of channels of the input image tensor.
        img_feat_dim: Width of the image feature vector produced by ``fc_img``.
        pose_dim: Width of the output; also the pose share of the
            ``pose_dim + action_dim`` input expected by ``mlp_motion``.
        action_dim: Action share of the ``mlp_motion`` input width.
        hidden_dim: Hidden width of ``mlp_motion`` and ``fc_out``.
    """

    def __init__(
        self,
        img_channels: int = 16,
        img_feat_dim: int = 256,
        pose_dim: int = 5,
        action_dim: int = 25,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()

        # Image encoder. The final adaptive pool collapses whatever spatial
        # size is left to 1x1, so the encoder does not depend on a fixed
        # input resolution.
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        # 128 matches the channel count of the last convolution above.
        self.fc_img = nn.Linear(128, img_feat_dim)

        # Encoder for the concatenated [pose, action] vector.
        self.mlp_motion = nn.Sequential(
            nn.Linear(pose_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Fusion head: image features + motion features -> pose_dim outputs.
        self.fc_out = nn.Sequential(
            nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(
        self,
        img: torch.Tensor,
        action: torch.Tensor,
        pose: torch.Tensor,
    ) -> torch.Tensor:
        """Run the network on one batch.

        Args:
            img: Batched image tensor ``(N, img_channels, H, W)``.
            action: 2-D tensor, concatenated after ``pose`` along dim 1;
                presumably ``(N, action_dim)`` -- confirm against the caller.
            pose: 2-D tensor, concatenated before ``action`` along dim 1;
                presumably ``(N, pose_dim)`` -- confirm against the caller.
                The combined width must equal ``pose_dim + action_dim``.

        Returns:
            Tensor of shape ``(N, pose_dim)``. The original code names it
            ``pose_next_pred``, so it is presumably the predicted next pose.
        """
        # (N, 128, 1, 1) after pooling -> (N, 128) for the linear projection.
        pooled = torch.flatten(self.cnn(img), start_dim=1)
        img_feat = self.fc_img(pooled)

        # Order matters: pose first, then action, matching the trained weights.
        motion_in = torch.cat([pose, action], dim=1)
        motion_feat = self.mlp_motion(motion_in)

        fused_feat = torch.cat([img_feat, motion_feat], dim=1)
        return self.fc_out(fused_feat)