# doduo/matching.py
# Uploaded by stevetod: "Upload model (#1)" (commit 5189ac9)
import torch
import torch.nn.functional as F
from .geometry import coords_grid, generate_window_grid, normalize_coords
def global_correlation_softmax_prototype(
    feature0,
    feature1,
    value,
    pred_bidir_flow=False,
    corr_mask=None,
):
    """Aggregate ``value`` with a softmax over the global feature correlation.

    Every position of ``feature0`` is compared with every position of
    ``feature1`` (dot product scaled by ``sqrt(C)``); the scores are turned
    into a probability over ``feature1`` positions and used to average
    ``value``.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        value: [B, C1, H, W]
        pred_bidir_flow: if True, the transposed correlation is concatenated
            along the batch dimension, so both outputs have batch size 2*B.
        corr_mask: [B, H*W, H*W] or None. A bool mask suppresses the entries
            that are True before the softmax. A mask of any other dtype is
            multiplied onto the softmax output, which is then renormalised.
            The mask is applied after the optional bidirectional
            concatenation and must be broadcastable to the correlation.

    Returns:
        result: [B, C1, H, W] (batch 2*B when ``pred_bidir_flow``).
        correlation: [B, H*W, H*W] scaled correlation, with the bool mask
            (if any) already applied.
    """
    b, c, h, w = feature0.shape
    c_value = value.size(1)
    num_pos = h * w

    # reshape (rather than view) so that non-contiguous inputs, e.g. tensors
    # produced by transpose/permute, are accepted as well
    value = value.reshape(b, c_value, num_pos).permute(0, 2, 1)  # [B, H*W, C1]
    feature0 = feature0.reshape(b, c, num_pos).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.reshape(b, c, num_pos)  # [B, C, H*W]

    # global correlation
    correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        value = value.repeat(2, 1, 1)  # [2*B, H*W, C1]
        b = b * 2

    if corr_mask is None:
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
    elif corr_mask.dtype == torch.bool:
        # binary mask: -65504 is the most negative finite float16 value, so
        # the fill stays representable under half precision
        correlation = correlation.masked_fill(corr_mask, -65504.0)
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
    else:
        # float mask: re-weight the probabilities, then renormalise each row
        prob = F.softmax(correlation, dim=-1) * corr_mask
        prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)

    result = (
        torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2)
    )  # [B, C1, H, W]
    return result, correlation
def local_correlation_softmax_prototype(
    feature0,
    feature1,
    value,
    radius=5,
    pred_bidir_flow=False,
    corr_mask=None,
):
    """Aggregate ``value`` with a softmax restricted to a disc around the argmax.

    The global correlation probability is computed first; for every query
    position only the targets whose coordinates lie strictly within
    ``radius`` of that query's best match are kept, and the remaining
    probabilities are renormalised before averaging ``value``.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        value: [B, C1, H, W]
        radius: targets with squared coordinate distance < radius**2 from the
            argmax position are kept.
        pred_bidir_flow: if True, the transposed correlation is concatenated
            along the batch dimension, so both outputs have batch size 2*B.
        corr_mask: [B, H*W, H*W] or None. A bool mask suppresses the entries
            that are True before the softmax; a mask of any other dtype is
            multiplied onto the softmax output. It must be broadcastable to
            the correlation (after the optional bidirectional concatenation).

    Returns:
        result: [B, C1, H, W] (batch 2*B when ``pred_bidir_flow``).
        correlation: [B, H*W, H*W] scaled correlation, with the bool mask
            (if any) already applied.
    """
    b, c, h, w = feature0.shape
    c_value = value.size(1)
    num_pos = h * w
    device = feature0.device

    # reshape (rather than view) so that non-contiguous inputs are accepted
    value = value.reshape(b, c_value, num_pos).permute(0, 2, 1)  # [B, H*W, C1]
    feature0 = feature0.reshape(b, c, num_pos).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.reshape(b, c, num_pos)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H*W, H*W]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        value = value.repeat(2, 1, 1)  # [2*B, H*W, C1]
        b = b * 2

    if corr_mask is None:
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
    elif corr_mask.dtype == torch.bool:
        # binary mask: -65504 is the most negative finite float16 value
        correlation = correlation.masked_fill(corr_mask, -65504.0)
        prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]
    else:
        # float mask: re-weight only; normalisation happens after the
        # locality mask below
        prob = F.softmax(correlation, dim=-1) * corr_mask

    # coordinates of every position, [B, H*W, 2]
    coords = coords_grid(b, h, w, device=device).flatten(2).permute(0, 2, 1)

    # coordinates of the best match of every query, [B, H*W, 2]
    argmax_pos = torch.argmax(prob, dim=2)  # [B, H*W]
    peak_index = argmax_pos.unsqueeze(-1).expand(-1, -1, coords.size(-1))
    peak = torch.gather(coords, 1, peak_index)

    # broadcast "target coords - peak coords" instead of materialising a
    # repeated [B, H*W, H*W, 2] copy of the coordinate table
    offset = coords.unsqueeze(1) - peak.unsqueeze(2)  # [B, H*W, H*W, 2]
    valid = (offset.square().sum(dim=-1) < (radius**2)).to(prob.dtype)  # [B, H*W, H*W]

    prob = prob * valid
    # normalize
    prob = prob / (prob.sum(dim=2, keepdim=True) + 1e-8)

    result = (
        torch.matmul(prob, value).view(b, h, w, c_value).permute(0, 3, 1, 2)
    )  # [B, C1, H, W]
    return result, correlation
def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """Estimate flow as the softmax-weighted expected matching coordinate.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        pred_bidir_flow: if True, the transposed correlation is concatenated
            along the batch dimension; the returned flow is then the
            concatenation of forward flow and backward flow (batch 2*B).

    Returns:
        flow: [B, 2, H, W], expected correspondence minus the initial grid.
        prob: [B, H*W, H*W] matching probability of every position pair.
    """
    b, c, h, w = feature0.shape
    num_pos = h * w

    # reshape (rather than view) so that non-contiguous inputs are accepted
    feature0 = feature0.reshape(b, c, num_pos).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.reshape(b, c, num_pos)  # [B, C, H*W]

    # global correlation
    correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H*W, H*W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.reshape(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    if pred_bidir_flow:
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # expected target coordinate of every source position
    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - init_grid
    return flow, prob
def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Estimate flow from a softmax over a local (2R+1) x (2R+1) window.

    For every position of ``feature0`` the features of ``feature1`` are
    sampled on a square window centred at the same position, correlated with
    the query feature, and the softmax-weighted mean of the sampled
    coordinates gives the correspondence.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        local_radius: window radius R.
        padding_mode: forwarded to ``F.grid_sample``.

    Returns:
        flow: [B, 2, H, W], expected correspondence minus the initial grid.
        match_prob: [B, H*W, (2R+1)^2] probability over the window samples.
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1
    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    # samples outside the image must not take part in the softmax
    sample_x = sample_coords[..., 0]
    sample_y = sample_coords[..., 1]
    valid = (
        (sample_x >= 0) & (sample_x < w) & (sample_y >= 0) & (sample_y < h)
    )  # [B, H*W, (2R+1)^2]

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
    ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]

    # reshape (rather than view) so that non-contiguous inputs are accepted
    feature0_view = feature0.permute(0, 2, 3, 1).reshape(b, h * w, 1, c)  # [B, H*W, 1, C]
    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations with the most negative finite value of the
    # working dtype (a hard-coded -1e9 is not representable in float16)
    corr = corr.masked_fill(~valid, torch.finfo(corr.dtype).min)
    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob
    return flow, match_prob
def local_correlation_with_flow(
    feature0,
    feature1,
    flow,
    local_radius,
    padding_mode="zeros",
    dilation=1,
):
    """Correlate ``feature0`` with a local window of ``feature1`` shifted by ``flow``.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        flow: tensor reshapeable to [B, 2, H*W] that displaces the window
            centres, or the Python scalar 0 / 0.0 meaning "no displacement".
        local_radius: window radius R.
        padding_mode: forwarded to ``F.grid_sample``.
        dilation: multiplier applied to the window offsets.

    Returns:
        corr: [B, (2R+1)^2, H, W] scaled correlation for every window sample.
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1
    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid * dilation  # [B, H*W, (2R+1)^2, 2]

    # flow can be zero when using features after transformer
    if isinstance(flow, (int, float)):
        # a scalar only stands for "no initial flow"; accept int 0 as well as 0.0
        assert flow == 0.0
    else:
        sample_coords = sample_coords + flow.reshape(b, 2, -1).permute(
            0, 2, 1
        ).unsqueeze(-2)  # [B, H*W, (2R+1)^2, 2]

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=False
    ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]

    # reshape (rather than view) so that non-contiguous inputs are accepted
    feature0_view = feature0.permute(0, 2, 3, 1).reshape(b, h * w, 1, c)  # [B, H*W, 1, C]
    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()  # [B, (2R+1)^2, H, W]
    return corr
def global_correlation_softmax_stereo(
    feature0,
    feature1,
):
    """Estimate disparity from a softmax over all positions of the same row.

    Every pixel of ``feature0`` is correlated with the pixels of the same row
    of ``feature1``. Candidates to the right of the pixel are masked out so
    the resulting disparity is non-negative.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]

    Returns:
        disparity: [B, 1, H, W], own x coordinate minus the expected matching
            x coordinate (feature resolution).
        prob: [B, H, W, W] matching probability along the row.
    """
    # global correlation on horizontal direction
    b, c, h, w = feature0.shape
    x_grid = torch.linspace(0, w - 1, w, device=feature0.device)  # [W]

    feature0 = feature0.permute(0, 2, 3, 1)  # [B, H, W, C]
    feature1 = feature1.permute(0, 2, 1, 3)  # [B, H, C, W]
    correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H, W, W]

    # mask subsequent positions to make disparity positive: candidate column j
    # is invalid for source column i when j > i. The [W, W] mask is built on
    # the target device and broadcast over batch and rows.
    columns = torch.arange(w, device=correlation.device)
    invalid = columns.view(1, w) > columns.view(w, 1)  # [W, W]
    # most negative finite value of the working dtype (a hard-coded -1e9 is
    # not representable in float16)
    correlation = correlation.masked_fill(invalid, torch.finfo(correlation.dtype).min)

    prob = F.softmax(correlation, dim=-1)  # [B, H, W, W]
    correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1)  # [B, H, W]

    # NOTE: unlike flow, disparity is typically positive
    disparity = x_grid.view(1, 1, w) - correspondence  # [B, H, W]
    return disparity.unsqueeze(1), prob  # feature resolution
def local_correlation_softmax_stereo(
    feature0,
    feature1,
    local_radius,
):
    """Estimate disparity from a softmax over a horizontal 1 x (2R+1) window.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        local_radius: horizontal window radius R.

    Returns:
        flow_x: [B, 1, H, W], negated horizontal component of the flow
            (feature resolution).
        match_prob: [B, H*W, 2R+1] probability over the window samples.
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous()  # [B, H*W, 2]

    # the window only extends along x
    local_h = 1
    local_w = 2 * local_radius + 1
    window_grid = generate_window_grid(
        0, 0, -local_radius, local_radius, local_h, local_w, device=feature0.device
    )  # [1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1), 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1), 2]

    # samples outside the image must not take part in the softmax
    sample_x = sample_coords[..., 0]
    sample_y = sample_coords[..., 1]
    valid = (
        (sample_x >= 0) & (sample_x < w) & (sample_y >= 0) & (sample_y < h)
    )  # [B, H*W, (2R+1)]

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode="zeros", align_corners=False
    ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)]

    feature0_view = (
        feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c)
    )  # [B, H*W, 1, C]
    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)]

    # mask invalid locations with the most negative finite value of the
    # working dtype (a hard-coded -1e9 is not representable in float16)
    corr = corr.masked_fill(~valid, torch.finfo(corr.dtype).min)
    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)]

    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
        .contiguous()
    )  # [B, 2, H, W]

    flow = correspondence - coords_init  # flow at feature resolution
    match_prob = prob

    # sign flipped so that the returned horizontal displacement is the disparity
    flow_x = -flow[:, :1]  # [B, 1, H, W]
    return flow_x, match_prob
def correlation_softmax_depth(
    feature0,
    feature1,
    intrinsics,
    pose,
    depth_candidates,
    depth_from_argmax=False,
    pred_bidir_depth=False,
):
    """Pick a depth value per pixel by matching against pose-warped features.

    ``feature1`` is warped towards ``feature0`` once per depth candidate, the
    warped features are correlated with ``feature0`` and a softmax over the
    candidate dimension gives the matching probability.

    Args:
        feature0: [B, C, H, W]
        feature1: [B, C, H, W]
        intrinsics: camera intrinsics, repeated along the batch for the
            bidirectional case.
        pose: relative pose; its inverse is appended for the bidirectional case.
        depth_candidates: [B, D, H, W]; these are inverse depths, their
            reciprocal is what gets passed to the warping routine.
        depth_from_argmax: take the most probable candidate instead of the
            probability-weighted sum of candidates.
        pred_bidir_depth: also predict for the swapped image pair, which
            doubles the batch size of the outputs.

    Returns:
        depth: [B, 1, H, W], in the same units as ``depth_candidates``.
        match_prob: [B, D, H, W]
    """
    _, channels, _, _ = feature0.size()
    assert depth_candidates.dim() == 4  # [B, D, H, W]

    if pred_bidir_depth:
        first, second = feature0, feature1
        feature0 = torch.cat((first, second), dim=0)
        feature1 = torch.cat((second, first), dim=0)
        intrinsics = intrinsics.repeat(2, 1, 1)
        pose = torch.cat((pose, torch.inverse(pose)), dim=0)
        depth_candidates = depth_candidates.repeat(2, 1, 1, 1)

    # depth candidates are actually inverse depth
    candidate_depth = 1.0 / depth_candidates
    warped_feature1 = warp_with_pose_depth_candidates(
        feature1,
        intrinsics,
        pose,
        candidate_depth,
    )  # [B, C, D, H, W]

    matching_score = (feature0.unsqueeze(2) * warped_feature1).sum(1)  # [B, D, H, W]
    correlation = matching_score / channels**0.5
    match_prob = F.softmax(correlation, dim=1)  # [B, D, H, W]

    if not depth_from_argmax:
        expected = (match_prob * depth_candidates).sum(dim=1, keepdim=True)  # [B, 1, H, W]
        return expected, match_prob

    # for cross-task transfer (flow -> depth), extract depth with argmax at test time
    best = torch.argmax(match_prob, dim=1, keepdim=True)
    return torch.gather(depth_candidates, dim=1, index=best), match_prob
def warp_with_pose_depth_candidates(
    feature1,
    intrinsics,
    pose,
    depth,
    clamp_min_depth=1e-3,
):
    """Warp ``feature1`` once per depth candidate using the given camera pose.

    Pixels are back-projected with the inverse intrinsics, scaled by each
    depth candidate, transformed by ``pose``, re-projected with the
    intrinsics, and ``feature1`` is bilinearly sampled at the resulting
    locations. The sampling coordinates are computed under
    ``torch.no_grad()``; only the sampling itself is tracked by autograd.

    feature1: [B, C, H, W]
    intrinsics: [B, 3, 3]
    pose: [B, 4, 4]
    depth: [B, D, H, W]
    clamp_min_depth: lower bound applied to the projected z before dividing.

    Returns the warped features, [B, C, D, H, W].
    """
    assert intrinsics.size(1) == intrinsics.size(2) == 3
    assert pose.size(1) == pose.size(2) == 4
    assert depth.dim() == 4

    b, d, h, w = depth.size()
    c = feature1.size(1)
    num_pix = h * w

    with torch.no_grad():
        # homogeneous pixel coordinates
        pixels = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]

        # back project to 3D and transform viewpoint
        rays = torch.inverse(intrinsics).bmm(pixels.view(b, 3, -1))  # [B, 3, H*W]
        rotation = pose[:, :3, :3]
        translation = pose[:, :3, -1:]
        rotated = torch.bmm(rotation, rays).unsqueeze(2).repeat(1, 1, d, 1)  # [B, 3, D, H*W]
        points = rotated * depth.view(b, 1, d, num_pix)  # [B, 3, D, H*W]
        points = points + translation.unsqueeze(-1)  # [B, 3, D, H*W]

        # reproject to 2D image plane
        projected = torch.bmm(intrinsics, points.view(b, 3, -1)).view(
            b, 3, d, num_pix
        )  # [B, 3, D, H*W]
        z = projected[:, -1:].clamp(min=clamp_min_depth)
        pixel_coords = projected[:, :2] / z  # [B, 2, D, H*W]

        # normalize to [-1, 1]
        # NOTE(review): dividing by (w - 1) / (h - 1) matches the
        # align_corners=True convention while grid_sample below is called with
        # align_corners=False -- confirm before changing, it affects outputs.
        x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
        y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
        grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, D, H*W, 2]

    # sample features
    sampled = F.grid_sample(
        feature1,
        grid.view(b, d * h, w, 2),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=False,
    )
    warped_feature = sampled.view(b, c, d, h, w)  # [B, C, D, H, W]
    return warped_feature