SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import torch
import torch.nn.functional as F
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of pixel-coordinate grids.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    every pixel. With ``homogeneous=True`` a third channel of ones is added.

    Args:
        b: Batch size; the same grid is repeated ``b`` times.
        h: Grid height in pixels.
        w: Grid width in pixels.
        homogeneous: Append a constant ``1`` channel when true.
        device: Device to create the grid on; ``None`` uses torch's default.

    Returns:
        A float32 tensor of shape ``[B, 2, H, W]`` or ``[B, 3, H, W]``.
    """
    # Broadcasting two aranges yields the same values as a row-major ("ij")
    # meshgrid, but avoids the deprecation warning ``torch.meshgrid`` emits
    # when called without ``indexing`` and allocates directly on ``device``
    # instead of building on the default device and copying afterwards.
    ys = torch.arange(h, device=device).view(h, 1).expand(h, w)  # [H, W]
    xs = torch.arange(w, device=device).view(1, w).expand(h, w)  # [H, W]

    channels = [xs, ys]
    if homogeneous:
        channels.append(torch.ones_like(xs))  # [H, W]

    grid = torch.stack(channels, dim=0).float()  # [2, H, W] or [3, H, W]
    return grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
def warp_with_pose_depth_candidates(
    feature1,
    intrinsics,
    pose,
    depth,
    clamp_min_depth=1e-3,
    grid_sample_disable_cudnn=False,
):
    """Warp a feature map into another view for every depth candidate.

    Each pixel of an ``H x W`` grid is back-projected with the inverse
    intrinsics, scaled by each of its ``D`` depth candidates, transformed by
    ``pose``, re-projected with ``intrinsics`` and used to bilinearly sample
    ``feature1``.

    Args:
        feature1: Features to sample from, ``[B, C, H, W]``.
        intrinsics: Camera intrinsics, ``[B, 3, 3]``.
        pose: Rigid transform applied to the back-projected points,
            ``[B, 4, 4]`` (rotation in ``[:3, :3]``, translation in
            ``[:3, 3]``).
        depth: Depth candidates per pixel, ``[B, D, H, W]``.
        clamp_min_depth: Lower bound applied to the projected depth before
            the perspective division.
        grid_sample_disable_cudnn: Run ``grid_sample`` with cuDNN disabled.
            Forced on when ``feature1`` has more than 1,000,000 elements.

    Returns:
        Warped features of shape ``[B, C, D, H, W]``; samples that fall
        outside ``feature1`` are zero.
    """
    assert intrinsics.size(1) == intrinsics.size(2) == 3
    assert pose.size(1) == pose.size(2) == 4
    assert depth.dim() == 4

    b, d, h, w = depth.size()
    c = feature1.size(1)

    # The sampling grid is built without autograd tracking, so the output is
    # differentiable with respect to ``feature1`` only.
    with torch.no_grad():
        # pixel coordinates
        pixels = coords_grid(
            b, h, w, homogeneous=True, device=depth.device
        )  # [B, 3, H, W]

        # back project to 3D and rotate into the other view
        rays = torch.inverse(intrinsics).bmm(pixels.view(b, 3, -1))  # [B, 3, H*W]
        rays = torch.bmm(pose[:, :3, :3], rays)  # [B, 3, H*W]

        # Broadcasting over the depth axis gives the same values as an
        # explicit ``repeat`` without materialising the repeated rays.
        # ``reshape`` also accepts a non-contiguous ``depth``.
        points = rays.unsqueeze(2) * depth.reshape(b, 1, d, h * w)  # [B, 3, D, H*W]
        points = points + pose[:, :3, -1:].unsqueeze(-1)  # [B, 3, D, H*W]

        # reproject to 2D image plane
        points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(
            b, 3, d, h * w
        )  # [B, 3, D, H*W]
        pixel_coords = points[:, :2] / points[:, -1:].clamp(
            min=clamp_min_depth
        )  # [B, 2, D, H*W]

        # normalize to [-1, 1]; the ``max`` guard avoids a division by zero
        # (inf/NaN grid) when the image is a single pixel wide or tall.
        x_grid = 2 * pixel_coords[:, 0] / max(w - 1, 1) - 1
        y_grid = 2 * pixel_coords[:, 1] / max(h - 1, 1) - 1
        grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, D, H*W, 2]

    # sample features
    # ref: https://github.com/pytorch/pytorch/issues/88380
    # hardcoded workaround: large feature maps always bypass cuDNN
    disable_cudnn = grid_sample_disable_cudnn or feature1.numel() > 1000000

    with torch.backends.cudnn.flags(enabled=not disable_cudnn):
        warped_feature = F.grid_sample(
            feature1,
            grid.view(b, d * h, w, 2),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        ).view(
            b, c, d, h, w
        )  # [B, C, D, H, W]

    return warped_feature