File size: 1,460 Bytes
8652b14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as F

class PosePredictionNet(nn.Module):
    """Predict the next pose from an image, the current pose, and an action.

    The image is encoded by a small CNN into a fixed-size feature vector,
    the (pose, action) pair by a two-layer MLP; the two feature vectors are
    concatenated and mapped to a ``pose_dim``-sized prediction.
    """

    def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
        super().__init__()

        # Visual encoder: three stride-2 convs, then global average pooling
        # to a fixed 128-d descriptor regardless of spatial input size.
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Project the pooled 128-d descriptor to the requested feature size.
        self.fc_img = nn.Linear(128, img_feat_dim)

        # Motion encoder for the concatenated (pose, action) vector.
        self.mlp_motion = nn.Sequential(
            nn.Linear(pose_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Fusion head: combined visual + motion features -> next pose.
        self.fc_out = nn.Sequential(
            nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(self, img, action, pose):
        """Return the predicted next pose, shape ``(batch, pose_dim)``.

        Args:
            img: image batch, ``(batch, img_channels, H, W)``.
            action: action batch, ``(batch, action_dim)``.
            pose: current pose batch, ``(batch, pose_dim)``.
        """
        visual = torch.flatten(self.cnn(img), start_dim=1)
        visual = self.fc_img(visual)

        # Pose comes first in the concatenation, matching mlp_motion's input layout.
        motion = self.mlp_motion(torch.cat((pose, action), dim=1))

        fused = torch.cat((visual, motion), dim=1)
        return self.fc_out(fused)