|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class CameraDec(nn.Module):
    """MLP head that decodes per-view features into a pose encoding.

    A shared two-layer ReLU MLP (``backbone``) feeds three linear heads:

    * ``fc_t``    -> 3 values,
    * ``fc_qvec`` -> 4 values (raw head output; no normalization is applied
      here),
    * ``fc_fov``  -> 2 values passed through a ReLU, so they are >= 0.

    NOTE(review): the names suggest translation / rotation quaternion /
    field of view; confirm against the consumer of the encoding.

    Args:
        dim_in: Size of the flattened per-view feature vector. It is also
            used as the width of both hidden layers.
    """

    def __init__(self, dim_in: int = 1536):
        super().__init__()
        hidden_dim = dim_in
        # Submodule names and registration order are kept stable so that
        # existing checkpoints (state_dict keys) continue to load.
        self.backbone = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_t = nn.Linear(hidden_dim, 3)
        self.fc_qvec = nn.Linear(hidden_dim, 4)
        self.fc_fov = nn.Sequential(nn.Linear(hidden_dim, 2), nn.ReLU())

    def forward(self, feat: torch.Tensor, camera_encoding=None, *args, **kwargs) -> torch.Tensor:
        """Decode features into ``[t, qvec, fov]`` concatenated on the last axis.

        Args:
            feat: Tensor whose first two dimensions are ``(B, N)``. All
                remaining dimensions are flattened per ``(b, n)`` entry and
                must total ``dim_in`` elements to match the first linear
                layer.
            camera_encoding: Optional tensor. When given, the ``qvec`` and
                ``fov`` heads are skipped and ``camera_encoding[..., 3:7]``
                and ``camera_encoding[..., -2:]`` are used in their place;
                only the ``t`` part is predicted.
                Presumably shaped ``(B, N, 9)`` with the same
                ``[t(3), qvec(4), fov(2)]`` layout as the output -- TODO
                confirm with the caller. Its leading dimensions must be
                ``(B, N)`` for the concatenation to succeed.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, N, 9)`` when ``camera_encoding`` is ``None``
            (3 + 4 + 2 values per entry).
        """
        B, N = feat.shape[:2]
        hidden = self.backbone(feat.reshape(B * N, -1))
        # Cast to float32 once; the original repeated this identical cast
        # before every head.
        hidden = hidden.float()

        out_t = self.fc_t(hidden).reshape(B, N, 3)
        if camera_encoding is None:
            out_qvec = self.fc_qvec(hidden).reshape(B, N, 4)
            out_fov = self.fc_fov(hidden).reshape(B, N, 2)
        else:
            # Reuse the supplied values instead of predicting them.
            out_qvec = camera_encoding[..., 3:7]
            out_fov = camera_encoding[..., -2:]

        return torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
|
|