| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class CameraDec(nn.Module):
    """Decode per-view feature vectors into a concatenated camera encoding.

    A shared two-layer ReLU MLP feeds three linear heads with 3, 4 and 2
    outputs (``fc_t``, ``fc_qvec``, ``fc_fov``). Judging by the attribute
    names these are presumably translation, rotation quaternion and field of
    view -- confirm against the consumer of the encoding.

    Args:
        dim_in: Width of the incoming feature vectors; also used as the
            hidden width of the shared MLP.
    """

    def __init__(self, dim_in=1536):
        super().__init__()
        width = dim_in
        # Shared trunk. Built from a list so the Sequential indices (and the
        # resulting state_dict keys) stay 0..3.
        trunk_layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)
        self.fc_t = nn.Linear(width, 3)
        self.fc_qvec = nn.Linear(width, 4)
        # The trailing ReLU keeps both outputs of this head non-negative.
        self.fc_fov = nn.Sequential(nn.Linear(width, 2), nn.ReLU())

    def forward(self, feat, camera_encoding=None, *args, **kwargs):
        """Run the decoder and return the concatenated encoding.

        Args:
            feat: Tensor whose first two dimensions are taken as ``(B, N)``;
                all remaining dimensions are flattened into the feature axis.
            camera_encoding: Optional tensor. When given, its ``[..., 3:7]``
                and ``[..., -2:]`` slices are used in place of the
                ``fc_qvec`` / ``fc_fov`` predictions.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            ``torch.cat`` along the last axis of the ``(B, N, 3)`` ``fc_t``
            output, the 4-wide group and the 2-wide group -- shape
            ``(B, N, 9)`` when ``camera_encoding`` is ``None``.
        """
        batch, views = feat.shape[:2]
        flat = feat.reshape(batch * views, -1)
        flat = flat.to(dtype=self.backbone[0].weight.dtype)
        hidden = self.backbone(flat)

        def run_head(head, weight, channels):
            # Cast to the head's own parameter dtype before applying it, then
            # restore the (B, N, channels) layout.
            projected = head(hidden.to(dtype=weight.dtype))
            return projected.reshape(batch, views, channels)

        # The 3-wide head is always predicted from the features.
        parts = [run_head(self.fc_t, self.fc_t.weight, 3)]

        if camera_encoding is not None:
            # Caller-supplied values replace the two remaining groups.
            parts.append(camera_encoding[..., 3:7])
            parts.append(camera_encoding[..., -2:])
        else:
            parts.append(run_head(self.fc_qvec, self.fc_qvec.weight, 4))
            parts.append(run_head(self.fc_fov, self.fc_fov[0].weight, 2))

        return torch.cat(parts, dim=-1)
|
|