Spaces:
Running
Running
unimatch demo
Browse files- app.py +160 -0
- dataloader/__init__.py +0 -0
- dataloader/stereo/transforms.py +434 -0
- demo/flow_davis_skate-jump_00059.jpg +0 -0
- demo/flow_davis_skate-jump_00060.jpg +0 -0
- demo/flow_kitti_test_000197_10.png +0 -0
- demo/flow_kitti_test_000197_11.png +0 -0
- demo/flow_sintel_cave_3_frame_0049.png +0 -0
- demo/flow_sintel_cave_3_frame_0050.png +0 -0
- demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg +0 -0
- demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg +0 -0
- pretrained/tmp.txt +0 -0
- requirements.txt +5 -0
- unimatch/__init__.py +0 -0
- unimatch/attention.py +253 -0
- unimatch/backbone.py +117 -0
- unimatch/geometry.py +195 -0
- unimatch/matching.py +279 -0
- unimatch/position.py +46 -0
- unimatch/reg_refine.py +119 -0
- unimatch/transformer.py +294 -0
- unimatch/trident_conv.py +90 -0
- unimatch/unimatch.py +367 -0
- unimatch/utils.py +216 -0
- utils/flow_viz.py +290 -0
- utils/visualization.py +110 -0
app.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from unimatch.unimatch import UniMatch
|
| 8 |
+
from utils.flow_viz import flow_to_image
|
| 9 |
+
from dataloader.stereo import transforms
|
| 10 |
+
from utils.visualization import vis_disparity
|
| 11 |
+
|
| 12 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 13 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@torch.no_grad()
|
| 17 |
+
def inference(image1, image2, task='flow'):
|
| 18 |
+
"""Inference on an image pair for optical flow or stereo disparity prediction"""
|
| 19 |
+
|
| 20 |
+
model = UniMatch(feature_channels=128,
|
| 21 |
+
num_scales=2,
|
| 22 |
+
upsample_factor=4,
|
| 23 |
+
ffn_dim_expansion=4,
|
| 24 |
+
num_transformer_layers=6,
|
| 25 |
+
reg_refine=True,
|
| 26 |
+
task=task)
|
| 27 |
+
|
| 28 |
+
model.eval()
|
| 29 |
+
|
| 30 |
+
if task == 'flow':
|
| 31 |
+
checkpoint_path = 'pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth'
|
| 32 |
+
else:
|
| 33 |
+
checkpoint_path = 'pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth'
|
| 34 |
+
|
| 35 |
+
checkpoint_flow = torch.load(checkpoint_path)
|
| 36 |
+
model.load_state_dict(checkpoint_flow['model'], strict=True)
|
| 37 |
+
|
| 38 |
+
padding_factor = 32
|
| 39 |
+
attn_type = 'swin' if task == 'flow' else 'self_swin2d_cross_swin1d'
|
| 40 |
+
attn_splits_list = [2, 8]
|
| 41 |
+
corr_radius_list = [-1, 4]
|
| 42 |
+
prop_radius_list = [-1, 1]
|
| 43 |
+
num_reg_refine = 6 if task == 'flow' else 3
|
| 44 |
+
|
| 45 |
+
# smaller inference size for faster speed
|
| 46 |
+
max_inference_size = [384, 768] if task == 'flow' else [640, 960]
|
| 47 |
+
|
| 48 |
+
transpose_img = False
|
| 49 |
+
|
| 50 |
+
image1 = np.array(image1).astype(np.float32)
|
| 51 |
+
image2 = np.array(image2).astype(np.float32)
|
| 52 |
+
|
| 53 |
+
if len(image1.shape) == 2: # gray image
|
| 54 |
+
image1 = np.tile(image1[..., None], (1, 1, 3))
|
| 55 |
+
image2 = np.tile(image2[..., None], (1, 1, 3))
|
| 56 |
+
else:
|
| 57 |
+
image1 = image1[..., :3]
|
| 58 |
+
image2 = image2[..., :3]
|
| 59 |
+
|
| 60 |
+
if task == 'flow':
|
| 61 |
+
image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0)
|
| 62 |
+
image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0)
|
| 63 |
+
else:
|
| 64 |
+
val_transform_list = [transforms.ToTensor(),
|
| 65 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
val_transform = transforms.Compose(val_transform_list)
|
| 69 |
+
|
| 70 |
+
sample = {'left': image1, 'right': image2}
|
| 71 |
+
sample = val_transform(sample)
|
| 72 |
+
|
| 73 |
+
image1 = sample['left'].unsqueeze(0) # [1, 3, H, W]
|
| 74 |
+
image2 = sample['right'].unsqueeze(0) # [1, 3, H, W]
|
| 75 |
+
|
| 76 |
+
# the model is trained with size: width > height
|
| 77 |
+
if task == 'flow' and image1.size(-2) > image1.size(-1):
|
| 78 |
+
image1 = torch.transpose(image1, -2, -1)
|
| 79 |
+
image2 = torch.transpose(image2, -2, -1)
|
| 80 |
+
transpose_img = True
|
| 81 |
+
|
| 82 |
+
nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
|
| 83 |
+
int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]
|
| 84 |
+
|
| 85 |
+
inference_size = [min(max_inference_size[0], nearest_size[0]), min(max_inference_size[1], nearest_size[1])]
|
| 86 |
+
|
| 87 |
+
assert isinstance(inference_size, list) or isinstance(inference_size, tuple)
|
| 88 |
+
ori_size = image1.shape[-2:]
|
| 89 |
+
|
| 90 |
+
# resize before inference
|
| 91 |
+
if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
|
| 92 |
+
image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
|
| 93 |
+
align_corners=True)
|
| 94 |
+
image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
|
| 95 |
+
align_corners=True)
|
| 96 |
+
|
| 97 |
+
results_dict = model(image1, image2,
|
| 98 |
+
attn_type=attn_type,
|
| 99 |
+
attn_splits_list=attn_splits_list,
|
| 100 |
+
corr_radius_list=corr_radius_list,
|
| 101 |
+
prop_radius_list=prop_radius_list,
|
| 102 |
+
num_reg_refine=num_reg_refine,
|
| 103 |
+
task=task,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
flow_pr = results_dict['flow_preds'][-1] # [1, 2, H, W] or [1, H, W]
|
| 107 |
+
|
| 108 |
+
# resize back
|
| 109 |
+
if task == 'flow':
|
| 110 |
+
if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
|
| 111 |
+
flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
|
| 112 |
+
align_corners=True)
|
| 113 |
+
flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
|
| 114 |
+
flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
|
| 115 |
+
else:
|
| 116 |
+
if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
|
| 117 |
+
pred_disp = F.interpolate(flow_pr.unsqueeze(1), size=ori_size,
|
| 118 |
+
mode='bilinear',
|
| 119 |
+
align_corners=True).squeeze(1) # [1, H, W]
|
| 120 |
+
pred_disp = pred_disp * ori_size[-1] / float(inference_size[-1])
|
| 121 |
+
|
| 122 |
+
if task == 'flow':
|
| 123 |
+
if transpose_img:
|
| 124 |
+
flow_pr = torch.transpose(flow_pr, -2, -1)
|
| 125 |
+
|
| 126 |
+
flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
|
| 127 |
+
|
| 128 |
+
output = flow_to_image(flow) # [H, W, 3]
|
| 129 |
+
else:
|
| 130 |
+
disp = pred_disp[0].cpu().numpy()
|
| 131 |
+
|
| 132 |
+
output = vis_disparity(disp, return_rgb=True)
|
| 133 |
+
|
| 134 |
+
return Image.fromarray(output)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
title = "UniMatch"
|
| 138 |
+
|
| 139 |
+
description = "<p style='text-align: center'>Optical flow and stereo matching demo for <a href='https://haofeixu.github.io/unimatch/' target='_blank'>Unifying Flow, Stereo and Depth Estimation</a> | <a href='https://arxiv.org/abs/2211.05783' target='_blank'>Paper</a> | <a href='https://github.com/autonomousvision/unimatch' target='_blank'>Code</a> | <a href='https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing' target='_blank'>Colab</a><br>Simply upload your images or click one of the provided examples.<br>The <strong>first three</strong> examples are video frames for <strong>flow</strong> task, and the <strong>last three</strong> are stereo pairs for <strong>stereo</strong> task.<br><strong>Select the task type according to your input images</strong>.</p>"
|
| 140 |
+
|
| 141 |
+
examples = [
|
| 142 |
+
['demo/flow_kitti_test_000197_10.png', 'demo/flow_kitti_test_000197_11.png'],
|
| 143 |
+
['demo/flow_sintel_cave_3_frame_0049.png', 'demo/flow_sintel_cave_3_frame_0050.png'],
|
| 144 |
+
['demo/flow_davis_skate-jump_00059.jpg', 'demo/flow_davis_skate-jump_00060.jpg'],
|
| 145 |
+
['demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg',
|
| 146 |
+
'demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg'],
|
| 147 |
+
['demo/stereo_middlebury_plants_im0.png', 'demo/stereo_middlebury_plants_im1.png'],
|
| 148 |
+
['demo/stereo_holopix_left.png', 'demo/stereo_holopix_right.png']
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
gr.Interface(
|
| 152 |
+
inference,
|
| 153 |
+
[gr.Image(type="pil", label="Image1"), gr.Image(type="pil", label="Image2"), gr.Radio(choices=['flow', 'stereo'], value='flow', label='Task')],
|
| 154 |
+
gr.Image(type="pil", label="Flow/Disparity"),
|
| 155 |
+
title=title,
|
| 156 |
+
description=description,
|
| 157 |
+
examples=examples,
|
| 158 |
+
thumbnail="https://haofeixu.github.io/unimatch/resources/teaser.svg",
|
| 159 |
+
allow_flagging="auto",
|
| 160 |
+
).launch(debug=True)
|
dataloader/__init__.py
ADDED
|
File without changes
|
dataloader/stereo/transforms.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchvision.transforms.functional as F
|
| 6 |
+
import random
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Compose(object):
|
| 11 |
+
def __init__(self, transforms):
|
| 12 |
+
self.transforms = transforms
|
| 13 |
+
|
| 14 |
+
def __call__(self, sample):
|
| 15 |
+
for t in self.transforms:
|
| 16 |
+
sample = t(sample)
|
| 17 |
+
return sample
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ToTensor(object):
|
| 21 |
+
"""Convert numpy array to torch tensor"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, no_normalize=False):
|
| 24 |
+
self.no_normalize = no_normalize
|
| 25 |
+
|
| 26 |
+
def __call__(self, sample):
|
| 27 |
+
left = np.transpose(sample['left'], (2, 0, 1)) # [3, H, W]
|
| 28 |
+
if self.no_normalize:
|
| 29 |
+
sample['left'] = torch.from_numpy(left)
|
| 30 |
+
else:
|
| 31 |
+
sample['left'] = torch.from_numpy(left) / 255.
|
| 32 |
+
right = np.transpose(sample['right'], (2, 0, 1))
|
| 33 |
+
|
| 34 |
+
if self.no_normalize:
|
| 35 |
+
sample['right'] = torch.from_numpy(right)
|
| 36 |
+
else:
|
| 37 |
+
sample['right'] = torch.from_numpy(right) / 255.
|
| 38 |
+
|
| 39 |
+
# disp = np.expand_dims(sample['disp'], axis=0) # [1, H, W]
|
| 40 |
+
if 'disp' in sample.keys():
|
| 41 |
+
disp = sample['disp'] # [H, W]
|
| 42 |
+
sample['disp'] = torch.from_numpy(disp)
|
| 43 |
+
|
| 44 |
+
return sample
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Normalize(object):
|
| 48 |
+
"""Normalize image, with type tensor"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, mean, std):
|
| 51 |
+
self.mean = mean
|
| 52 |
+
self.std = std
|
| 53 |
+
|
| 54 |
+
def __call__(self, sample):
|
| 55 |
+
|
| 56 |
+
norm_keys = ['left', 'right']
|
| 57 |
+
|
| 58 |
+
for key in norm_keys:
|
| 59 |
+
# Images have converted to tensor, with shape [C, H, W]
|
| 60 |
+
for t, m, s in zip(sample[key], self.mean, self.std):
|
| 61 |
+
t.sub_(m).div_(s)
|
| 62 |
+
|
| 63 |
+
return sample
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class RandomCrop(object):
|
| 67 |
+
def __init__(self, img_height, img_width):
|
| 68 |
+
self.img_height = img_height
|
| 69 |
+
self.img_width = img_width
|
| 70 |
+
|
| 71 |
+
def __call__(self, sample):
|
| 72 |
+
ori_height, ori_width = sample['left'].shape[:2]
|
| 73 |
+
|
| 74 |
+
# pad zero when crop size is larger than original image size
|
| 75 |
+
if self.img_height > ori_height or self.img_width > ori_width:
|
| 76 |
+
|
| 77 |
+
# can be used for only pad one side
|
| 78 |
+
top_pad = max(self.img_height - ori_height, 0)
|
| 79 |
+
right_pad = max(self.img_width - ori_width, 0)
|
| 80 |
+
|
| 81 |
+
# try edge padding
|
| 82 |
+
sample['left'] = np.lib.pad(sample['left'],
|
| 83 |
+
((top_pad, 0), (0, right_pad), (0, 0)),
|
| 84 |
+
mode='edge')
|
| 85 |
+
sample['right'] = np.lib.pad(sample['right'],
|
| 86 |
+
((top_pad, 0), (0, right_pad), (0, 0)),
|
| 87 |
+
mode='edge')
|
| 88 |
+
|
| 89 |
+
if 'disp' in sample.keys():
|
| 90 |
+
sample['disp'] = np.lib.pad(sample['disp'],
|
| 91 |
+
((top_pad, 0), (0, right_pad)),
|
| 92 |
+
mode='constant',
|
| 93 |
+
constant_values=0)
|
| 94 |
+
|
| 95 |
+
# update image resolution
|
| 96 |
+
ori_height, ori_width = sample['left'].shape[:2]
|
| 97 |
+
|
| 98 |
+
assert self.img_height <= ori_height and self.img_width <= ori_width
|
| 99 |
+
|
| 100 |
+
# Training: random crop
|
| 101 |
+
self.offset_x = np.random.randint(ori_width - self.img_width + 1)
|
| 102 |
+
|
| 103 |
+
start_height = 0
|
| 104 |
+
assert ori_height - start_height >= self.img_height
|
| 105 |
+
|
| 106 |
+
self.offset_y = np.random.randint(start_height, ori_height - self.img_height + 1)
|
| 107 |
+
|
| 108 |
+
sample['left'] = self.crop_img(sample['left'])
|
| 109 |
+
sample['right'] = self.crop_img(sample['right'])
|
| 110 |
+
if 'disp' in sample.keys():
|
| 111 |
+
sample['disp'] = self.crop_img(sample['disp'])
|
| 112 |
+
|
| 113 |
+
return sample
|
| 114 |
+
|
| 115 |
+
def crop_img(self, img):
|
| 116 |
+
return img[self.offset_y:self.offset_y + self.img_height,
|
| 117 |
+
self.offset_x:self.offset_x + self.img_width]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class RandomVerticalFlip(object):
|
| 121 |
+
"""Randomly vertically filps"""
|
| 122 |
+
|
| 123 |
+
def __call__(self, sample):
|
| 124 |
+
if np.random.random() < 0.5:
|
| 125 |
+
sample['left'] = np.copy(np.flipud(sample['left']))
|
| 126 |
+
sample['right'] = np.copy(np.flipud(sample['right']))
|
| 127 |
+
|
| 128 |
+
sample['disp'] = np.copy(np.flipud(sample['disp']))
|
| 129 |
+
|
| 130 |
+
return sample
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class ToPILImage(object):
|
| 134 |
+
|
| 135 |
+
def __call__(self, sample):
|
| 136 |
+
sample['left'] = Image.fromarray(sample['left'].astype('uint8'))
|
| 137 |
+
sample['right'] = Image.fromarray(sample['right'].astype('uint8'))
|
| 138 |
+
|
| 139 |
+
return sample
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ToNumpyArray(object):
|
| 143 |
+
|
| 144 |
+
def __call__(self, sample):
|
| 145 |
+
sample['left'] = np.array(sample['left']).astype(np.float32)
|
| 146 |
+
sample['right'] = np.array(sample['right']).astype(np.float32)
|
| 147 |
+
|
| 148 |
+
return sample
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Random coloring
|
| 152 |
+
class RandomContrast(object):
|
| 153 |
+
"""Random contrast"""
|
| 154 |
+
|
| 155 |
+
def __init__(self,
|
| 156 |
+
asymmetric_color_aug=True,
|
| 157 |
+
):
|
| 158 |
+
|
| 159 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 160 |
+
|
| 161 |
+
def __call__(self, sample):
|
| 162 |
+
if np.random.random() < 0.5:
|
| 163 |
+
contrast_factor = np.random.uniform(0.8, 1.2)
|
| 164 |
+
|
| 165 |
+
sample['left'] = F.adjust_contrast(sample['left'], contrast_factor)
|
| 166 |
+
|
| 167 |
+
if self.asymmetric_color_aug and np.random.random() < 0.5:
|
| 168 |
+
contrast_factor = np.random.uniform(0.8, 1.2)
|
| 169 |
+
|
| 170 |
+
sample['right'] = F.adjust_contrast(sample['right'], contrast_factor)
|
| 171 |
+
|
| 172 |
+
return sample
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class RandomGamma(object):
|
| 176 |
+
|
| 177 |
+
def __init__(self,
|
| 178 |
+
asymmetric_color_aug=True,
|
| 179 |
+
):
|
| 180 |
+
|
| 181 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 182 |
+
|
| 183 |
+
def __call__(self, sample):
|
| 184 |
+
if np.random.random() < 0.5:
|
| 185 |
+
gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet
|
| 186 |
+
|
| 187 |
+
sample['left'] = F.adjust_gamma(sample['left'], gamma)
|
| 188 |
+
|
| 189 |
+
if self.asymmetric_color_aug and np.random.random() < 0.5:
|
| 190 |
+
gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet
|
| 191 |
+
|
| 192 |
+
sample['right'] = F.adjust_gamma(sample['right'], gamma)
|
| 193 |
+
|
| 194 |
+
return sample
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class RandomBrightness(object):
|
| 198 |
+
|
| 199 |
+
def __init__(self,
|
| 200 |
+
asymmetric_color_aug=True,
|
| 201 |
+
):
|
| 202 |
+
|
| 203 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 204 |
+
|
| 205 |
+
def __call__(self, sample):
|
| 206 |
+
if np.random.random() < 0.5:
|
| 207 |
+
brightness = np.random.uniform(0.5, 2.0)
|
| 208 |
+
|
| 209 |
+
sample['left'] = F.adjust_brightness(sample['left'], brightness)
|
| 210 |
+
|
| 211 |
+
if self.asymmetric_color_aug and np.random.random() < 0.5:
|
| 212 |
+
brightness = np.random.uniform(0.5, 2.0)
|
| 213 |
+
|
| 214 |
+
sample['right'] = F.adjust_brightness(sample['right'], brightness)
|
| 215 |
+
|
| 216 |
+
return sample
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class RandomHue(object):
|
| 220 |
+
|
| 221 |
+
def __init__(self,
|
| 222 |
+
asymmetric_color_aug=True,
|
| 223 |
+
):
|
| 224 |
+
|
| 225 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 226 |
+
|
| 227 |
+
def __call__(self, sample):
|
| 228 |
+
if np.random.random() < 0.5:
|
| 229 |
+
hue = np.random.uniform(-0.1, 0.1)
|
| 230 |
+
|
| 231 |
+
sample['left'] = F.adjust_hue(sample['left'], hue)
|
| 232 |
+
|
| 233 |
+
if self.asymmetric_color_aug and np.random.random() < 0.5:
|
| 234 |
+
hue = np.random.uniform(-0.1, 0.1)
|
| 235 |
+
|
| 236 |
+
sample['right'] = F.adjust_hue(sample['right'], hue)
|
| 237 |
+
|
| 238 |
+
return sample
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class RandomSaturation(object):
|
| 242 |
+
|
| 243 |
+
def __init__(self,
|
| 244 |
+
asymmetric_color_aug=True,
|
| 245 |
+
):
|
| 246 |
+
|
| 247 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 248 |
+
|
| 249 |
+
def __call__(self, sample):
|
| 250 |
+
if np.random.random() < 0.5:
|
| 251 |
+
saturation = np.random.uniform(0.8, 1.2)
|
| 252 |
+
|
| 253 |
+
sample['left'] = F.adjust_saturation(sample['left'], saturation)
|
| 254 |
+
|
| 255 |
+
if self.asymmetric_color_aug and np.random.random() < 0.5:
|
| 256 |
+
saturation = np.random.uniform(0.8, 1.2)
|
| 257 |
+
|
| 258 |
+
sample['right'] = F.adjust_saturation(sample['right'], saturation)
|
| 259 |
+
|
| 260 |
+
return sample
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class RandomColor(object):
|
| 264 |
+
|
| 265 |
+
def __init__(self,
|
| 266 |
+
asymmetric_color_aug=True,
|
| 267 |
+
):
|
| 268 |
+
|
| 269 |
+
self.asymmetric_color_aug = asymmetric_color_aug
|
| 270 |
+
|
| 271 |
+
def __call__(self, sample):
|
| 272 |
+
transforms = [RandomContrast(asymmetric_color_aug=self.asymmetric_color_aug),
|
| 273 |
+
RandomGamma(asymmetric_color_aug=self.asymmetric_color_aug),
|
| 274 |
+
RandomBrightness(asymmetric_color_aug=self.asymmetric_color_aug),
|
| 275 |
+
RandomHue(asymmetric_color_aug=self.asymmetric_color_aug),
|
| 276 |
+
RandomSaturation(asymmetric_color_aug=self.asymmetric_color_aug)]
|
| 277 |
+
|
| 278 |
+
sample = ToPILImage()(sample)
|
| 279 |
+
|
| 280 |
+
if np.random.random() < 0.5:
|
| 281 |
+
# A single transform
|
| 282 |
+
t = random.choice(transforms)
|
| 283 |
+
sample = t(sample)
|
| 284 |
+
else:
|
| 285 |
+
# Combination of transforms
|
| 286 |
+
# Random order
|
| 287 |
+
random.shuffle(transforms)
|
| 288 |
+
for t in transforms:
|
| 289 |
+
sample = t(sample)
|
| 290 |
+
|
| 291 |
+
sample = ToNumpyArray()(sample)
|
| 292 |
+
|
| 293 |
+
return sample
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class RandomScale(object):
|
| 297 |
+
def __init__(self,
|
| 298 |
+
min_scale=-0.4,
|
| 299 |
+
max_scale=0.4,
|
| 300 |
+
crop_width=512,
|
| 301 |
+
nearest_interp=False, # for sparse gt
|
| 302 |
+
):
|
| 303 |
+
self.min_scale = min_scale
|
| 304 |
+
self.max_scale = max_scale
|
| 305 |
+
self.crop_width = crop_width
|
| 306 |
+
self.nearest_interp = nearest_interp
|
| 307 |
+
|
| 308 |
+
def __call__(self, sample):
|
| 309 |
+
if np.random.rand() < 0.5:
|
| 310 |
+
h, w = sample['disp'].shape
|
| 311 |
+
|
| 312 |
+
scale_x = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
| 313 |
+
|
| 314 |
+
scale_x = np.clip(scale_x, self.crop_width / float(w), None)
|
| 315 |
+
|
| 316 |
+
# only random scale x axis
|
| 317 |
+
sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)
|
| 318 |
+
sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)
|
| 319 |
+
|
| 320 |
+
sample['disp'] = cv2.resize(
|
| 321 |
+
sample['disp'], None, fx=scale_x, fy=1.,
|
| 322 |
+
interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
|
| 323 |
+
) * scale_x
|
| 324 |
+
|
| 325 |
+
if 'pseudo_disp' in sample and sample['pseudo_disp'] is not None:
|
| 326 |
+
sample['pseudo_disp'] = cv2.resize(sample['pseudo_disp'], None, fx=scale_x, fy=1.,
|
| 327 |
+
interpolation=cv2.INTER_LINEAR) * scale_x
|
| 328 |
+
|
| 329 |
+
return sample
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class Resize(object):
|
| 333 |
+
def __init__(self,
|
| 334 |
+
scale_x=1,
|
| 335 |
+
scale_y=1,
|
| 336 |
+
nearest_interp=True, # for sparse gt
|
| 337 |
+
):
|
| 338 |
+
"""
|
| 339 |
+
Resize low-resolution data to high-res for mixed dataset training
|
| 340 |
+
"""
|
| 341 |
+
self.scale_x = scale_x
|
| 342 |
+
self.scale_y = scale_y
|
| 343 |
+
self.nearest_interp = nearest_interp
|
| 344 |
+
|
| 345 |
+
def __call__(self, sample):
|
| 346 |
+
scale_x = self.scale_x
|
| 347 |
+
scale_y = self.scale_y
|
| 348 |
+
|
| 349 |
+
sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
| 350 |
+
sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
| 351 |
+
|
| 352 |
+
sample['disp'] = cv2.resize(
|
| 353 |
+
sample['disp'], None, fx=scale_x, fy=scale_y,
|
| 354 |
+
interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
|
| 355 |
+
) * scale_x
|
| 356 |
+
|
| 357 |
+
return sample
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class RandomGrayscale(object):
|
| 361 |
+
def __init__(self, p=0.2):
|
| 362 |
+
self.p = p
|
| 363 |
+
|
| 364 |
+
def __call__(self, sample):
|
| 365 |
+
if np.random.random() < self.p:
|
| 366 |
+
sample = ToPILImage()(sample)
|
| 367 |
+
|
| 368 |
+
# only supported in higher version pytorch
|
| 369 |
+
# default output channels is 1
|
| 370 |
+
sample['left'] = F.rgb_to_grayscale(sample['left'], num_output_channels=3)
|
| 371 |
+
sample['right'] = F.rgb_to_grayscale(sample['right'], num_output_channels=3)
|
| 372 |
+
|
| 373 |
+
sample = ToNumpyArray()(sample)
|
| 374 |
+
|
| 375 |
+
return sample
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class RandomRotateShiftRight(object):
|
| 379 |
+
def __init__(self, p=0.5):
|
| 380 |
+
self.p = p
|
| 381 |
+
|
| 382 |
+
def __call__(self, sample):
|
| 383 |
+
if np.random.random() < self.p:
|
| 384 |
+
angle, pixel = 0.1, 2
|
| 385 |
+
px = np.random.uniform(-pixel, pixel)
|
| 386 |
+
ag = np.random.uniform(-angle, angle)
|
| 387 |
+
|
| 388 |
+
right_img = sample['right']
|
| 389 |
+
|
| 390 |
+
image_center = (
|
| 391 |
+
np.random.uniform(0, right_img.shape[0]),
|
| 392 |
+
np.random.uniform(0, right_img.shape[1])
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
|
| 396 |
+
right_img = cv2.warpAffine(
|
| 397 |
+
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
| 398 |
+
)
|
| 399 |
+
trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
|
| 400 |
+
right_img = cv2.warpAffine(
|
| 401 |
+
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
sample['right'] = right_img
|
| 405 |
+
|
| 406 |
+
return sample
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class RandomOcclusion(object):
|
| 410 |
+
def __init__(self, p=0.5,
|
| 411 |
+
occlusion_mask_zero=False):
|
| 412 |
+
self.p = p
|
| 413 |
+
self.occlusion_mask_zero = occlusion_mask_zero
|
| 414 |
+
|
| 415 |
+
def __call__(self, sample):
|
| 416 |
+
bounds = [50, 100]
|
| 417 |
+
if np.random.random() < self.p:
|
| 418 |
+
img2 = sample['right']
|
| 419 |
+
ht, wd = img2.shape[:2]
|
| 420 |
+
|
| 421 |
+
if self.occlusion_mask_zero:
|
| 422 |
+
mean_color = 0
|
| 423 |
+
else:
|
| 424 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
| 425 |
+
|
| 426 |
+
x0 = np.random.randint(0, wd)
|
| 427 |
+
y0 = np.random.randint(0, ht)
|
| 428 |
+
dx = np.random.randint(bounds[0], bounds[1])
|
| 429 |
+
dy = np.random.randint(bounds[0], bounds[1])
|
| 430 |
+
img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
|
| 431 |
+
|
| 432 |
+
sample['right'] = img2
|
| 433 |
+
|
| 434 |
+
return sample
|
demo/flow_davis_skate-jump_00059.jpg
ADDED
|
demo/flow_davis_skate-jump_00060.jpg
ADDED
|
demo/flow_kitti_test_000197_10.png
ADDED
|
demo/flow_kitti_test_000197_11.png
ADDED
|
demo/flow_sintel_cave_3_frame_0049.png
ADDED
|
demo/flow_sintel_cave_3_frame_0050.png
ADDED
|
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg
ADDED
|
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg
ADDED
|
pretrained/tmp.txt
ADDED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
matplotlib
|
| 4 |
+
opencv-python
|
| 5 |
+
pillow
|
unimatch/__init__.py
ADDED
|
File without changes
|
unimatch/attention.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def single_head_full_attention(q, k, v):
|
| 9 |
+
# q, k, v: [B, L, C]
|
| 10 |
+
assert q.dim() == k.dim() == v.dim() == 3
|
| 11 |
+
|
| 12 |
+
scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
|
| 13 |
+
attn = torch.softmax(scores, dim=2) # [B, L, L]
|
| 14 |
+
out = torch.matmul(attn, v) # [B, L, C]
|
| 15 |
+
|
| 16 |
+
return out
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def single_head_full_attention_1d(q, k, v,
|
| 20 |
+
h=None,
|
| 21 |
+
w=None,
|
| 22 |
+
):
|
| 23 |
+
# q, k, v: [B, L, C]
|
| 24 |
+
|
| 25 |
+
assert h is not None and w is not None
|
| 26 |
+
assert q.size(1) == h * w
|
| 27 |
+
|
| 28 |
+
b, _, c = q.size()
|
| 29 |
+
|
| 30 |
+
q = q.view(b, h, w, c) # [B, H, W, C]
|
| 31 |
+
k = k.view(b, h, w, c)
|
| 32 |
+
v = v.view(b, h, w, c)
|
| 33 |
+
|
| 34 |
+
scale_factor = c ** 0.5
|
| 35 |
+
|
| 36 |
+
scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W]
|
| 37 |
+
|
| 38 |
+
attn = torch.softmax(scores, dim=-1)
|
| 39 |
+
|
| 40 |
+
out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C]
|
| 41 |
+
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def single_head_split_window_attention(q, k, v,
|
| 46 |
+
num_splits=1,
|
| 47 |
+
with_shift=False,
|
| 48 |
+
h=None,
|
| 49 |
+
w=None,
|
| 50 |
+
attn_mask=None,
|
| 51 |
+
):
|
| 52 |
+
# ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
|
| 53 |
+
# q, k, v: [B, L, C]
|
| 54 |
+
assert q.dim() == k.dim() == v.dim() == 3
|
| 55 |
+
|
| 56 |
+
assert h is not None and w is not None
|
| 57 |
+
assert q.size(1) == h * w
|
| 58 |
+
|
| 59 |
+
b, _, c = q.size()
|
| 60 |
+
|
| 61 |
+
b_new = b * num_splits * num_splits
|
| 62 |
+
|
| 63 |
+
window_size_h = h // num_splits
|
| 64 |
+
window_size_w = w // num_splits
|
| 65 |
+
|
| 66 |
+
q = q.view(b, h, w, c) # [B, H, W, C]
|
| 67 |
+
k = k.view(b, h, w, c)
|
| 68 |
+
v = v.view(b, h, w, c)
|
| 69 |
+
|
| 70 |
+
scale_factor = c ** 0.5
|
| 71 |
+
|
| 72 |
+
if with_shift:
|
| 73 |
+
assert attn_mask is not None # compute once
|
| 74 |
+
shift_size_h = window_size_h // 2
|
| 75 |
+
shift_size_w = window_size_w // 2
|
| 76 |
+
|
| 77 |
+
q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
|
| 78 |
+
k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
|
| 79 |
+
v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
|
| 80 |
+
|
| 81 |
+
q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
|
| 82 |
+
k = split_feature(k, num_splits=num_splits, channel_last=True)
|
| 83 |
+
v = split_feature(v, num_splits=num_splits, channel_last=True)
|
| 84 |
+
|
| 85 |
+
scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
|
| 86 |
+
) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
|
| 87 |
+
|
| 88 |
+
if with_shift:
|
| 89 |
+
scores += attn_mask.repeat(b, 1, 1)
|
| 90 |
+
|
| 91 |
+
attn = torch.softmax(scores, dim=-1)
|
| 92 |
+
|
| 93 |
+
out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
|
| 94 |
+
|
| 95 |
+
out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
|
| 96 |
+
num_splits=num_splits, channel_last=True) # [B, H, W, C]
|
| 97 |
+
|
| 98 |
+
# shift back
|
| 99 |
+
if with_shift:
|
| 100 |
+
out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
|
| 101 |
+
|
| 102 |
+
out = out.view(b, -1, c)
|
| 103 |
+
|
| 104 |
+
return out
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def single_head_split_window_attention_1d(q, k, v,
|
| 108 |
+
relative_position_bias=None,
|
| 109 |
+
num_splits=1,
|
| 110 |
+
with_shift=False,
|
| 111 |
+
h=None,
|
| 112 |
+
w=None,
|
| 113 |
+
attn_mask=None,
|
| 114 |
+
):
|
| 115 |
+
# q, k, v: [B, L, C]
|
| 116 |
+
|
| 117 |
+
assert h is not None and w is not None
|
| 118 |
+
assert q.size(1) == h * w
|
| 119 |
+
|
| 120 |
+
b, _, c = q.size()
|
| 121 |
+
|
| 122 |
+
b_new = b * num_splits * h
|
| 123 |
+
|
| 124 |
+
window_size_w = w // num_splits
|
| 125 |
+
|
| 126 |
+
q = q.view(b * h, w, c) # [B*H, W, C]
|
| 127 |
+
k = k.view(b * h, w, c)
|
| 128 |
+
v = v.view(b * h, w, c)
|
| 129 |
+
|
| 130 |
+
scale_factor = c ** 0.5
|
| 131 |
+
|
| 132 |
+
if with_shift:
|
| 133 |
+
assert attn_mask is not None # compute once
|
| 134 |
+
shift_size_w = window_size_w // 2
|
| 135 |
+
|
| 136 |
+
q = torch.roll(q, shifts=-shift_size_w, dims=1)
|
| 137 |
+
k = torch.roll(k, shifts=-shift_size_w, dims=1)
|
| 138 |
+
v = torch.roll(v, shifts=-shift_size_w, dims=1)
|
| 139 |
+
|
| 140 |
+
q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C]
|
| 141 |
+
k = split_feature_1d(k, num_splits=num_splits)
|
| 142 |
+
v = split_feature_1d(v, num_splits=num_splits)
|
| 143 |
+
|
| 144 |
+
scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
|
| 145 |
+
) / scale_factor # [B*H*K, W/K, W/K]
|
| 146 |
+
|
| 147 |
+
if with_shift:
|
| 148 |
+
# attn_mask: [K, W/K, W/K]
|
| 149 |
+
scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K]
|
| 150 |
+
|
| 151 |
+
attn = torch.softmax(scores, dim=-1)
|
| 152 |
+
|
| 153 |
+
out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C]
|
| 154 |
+
|
| 155 |
+
out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C]
|
| 156 |
+
|
| 157 |
+
# shift back
|
| 158 |
+
if with_shift:
|
| 159 |
+
out = torch.roll(out, shifts=shift_size_w, dims=2)
|
| 160 |
+
|
| 161 |
+
out = out.view(b, -1, c)
|
| 162 |
+
|
| 163 |
+
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SelfAttnPropagation(nn.Module):
|
| 167 |
+
"""
|
| 168 |
+
flow propagation with self-attention on feature
|
| 169 |
+
query: feature0, key: feature0, value: flow
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, in_channels,
|
| 173 |
+
**kwargs,
|
| 174 |
+
):
|
| 175 |
+
super(SelfAttnPropagation, self).__init__()
|
| 176 |
+
|
| 177 |
+
self.q_proj = nn.Linear(in_channels, in_channels)
|
| 178 |
+
self.k_proj = nn.Linear(in_channels, in_channels)
|
| 179 |
+
|
| 180 |
+
for p in self.parameters():
|
| 181 |
+
if p.dim() > 1:
|
| 182 |
+
nn.init.xavier_uniform_(p)
|
| 183 |
+
|
| 184 |
+
def forward(self, feature0, flow,
|
| 185 |
+
local_window_attn=False,
|
| 186 |
+
local_window_radius=1,
|
| 187 |
+
**kwargs,
|
| 188 |
+
):
|
| 189 |
+
# q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
|
| 190 |
+
if local_window_attn:
|
| 191 |
+
return self.forward_local_window_attn(feature0, flow,
|
| 192 |
+
local_window_radius=local_window_radius)
|
| 193 |
+
|
| 194 |
+
b, c, h, w = feature0.size()
|
| 195 |
+
|
| 196 |
+
query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
|
| 197 |
+
|
| 198 |
+
# a note: the ``correct'' implementation should be:
|
| 199 |
+
# ``query = self.q_proj(query), key = self.k_proj(query)''
|
| 200 |
+
# this problem is observed while cleaning up the code
|
| 201 |
+
# however, this doesn't affect the performance since the projection is a linear operation,
|
| 202 |
+
# thus the two projection matrices for key can be merged
|
| 203 |
+
# so I just leave it as is in order to not re-train all models :)
|
| 204 |
+
query = self.q_proj(query) # [B, H*W, C]
|
| 205 |
+
key = self.k_proj(query) # [B, H*W, C]
|
| 206 |
+
|
| 207 |
+
value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
|
| 208 |
+
|
| 209 |
+
scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
|
| 210 |
+
prob = torch.softmax(scores, dim=-1)
|
| 211 |
+
|
| 212 |
+
out = torch.matmul(prob, value) # [B, H*W, 2]
|
| 213 |
+
out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
|
| 214 |
+
|
| 215 |
+
return out
|
| 216 |
+
|
| 217 |
+
def forward_local_window_attn(self, feature0, flow,
|
| 218 |
+
local_window_radius=1,
|
| 219 |
+
):
|
| 220 |
+
assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth
|
| 221 |
+
assert local_window_radius > 0
|
| 222 |
+
|
| 223 |
+
b, c, h, w = feature0.size()
|
| 224 |
+
|
| 225 |
+
value_channel = flow.size(1)
|
| 226 |
+
|
| 227 |
+
feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
|
| 228 |
+
).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
|
| 229 |
+
|
| 230 |
+
kernel_size = 2 * local_window_radius + 1
|
| 231 |
+
|
| 232 |
+
feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
|
| 233 |
+
|
| 234 |
+
feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
|
| 235 |
+
padding=local_window_radius) # [B, C*(2R+1)^2), H*W]
|
| 236 |
+
|
| 237 |
+
feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
|
| 238 |
+
0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
|
| 239 |
+
|
| 240 |
+
flow_window = F.unfold(flow, kernel_size=kernel_size,
|
| 241 |
+
padding=local_window_radius) # [B, 2*(2R+1)^2), H*W]
|
| 242 |
+
|
| 243 |
+
flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
|
| 244 |
+
0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2]
|
| 245 |
+
|
| 246 |
+
scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
|
| 247 |
+
|
| 248 |
+
prob = torch.softmax(scores, dim=-1)
|
| 249 |
+
|
| 250 |
+
out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
|
| 251 |
+
).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
|
| 252 |
+
|
| 253 |
+
return out
|
unimatch/backbone.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from .trident_conv import MultiScaleTridentConv
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ResidualBlock(nn.Module):
|
| 7 |
+
def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
|
| 8 |
+
):
|
| 9 |
+
super(ResidualBlock, self).__init__()
|
| 10 |
+
|
| 11 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
|
| 12 |
+
dilation=dilation, padding=dilation, stride=stride, bias=False)
|
| 13 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
| 14 |
+
dilation=dilation, padding=dilation, bias=False)
|
| 15 |
+
self.relu = nn.ReLU(inplace=True)
|
| 16 |
+
|
| 17 |
+
self.norm1 = norm_layer(planes)
|
| 18 |
+
self.norm2 = norm_layer(planes)
|
| 19 |
+
if not stride == 1 or in_planes != planes:
|
| 20 |
+
self.norm3 = norm_layer(planes)
|
| 21 |
+
|
| 22 |
+
if stride == 1 and in_planes == planes:
|
| 23 |
+
self.downsample = None
|
| 24 |
+
else:
|
| 25 |
+
self.downsample = nn.Sequential(
|
| 26 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
| 27 |
+
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
y = x
|
| 30 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 31 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 32 |
+
|
| 33 |
+
if self.downsample is not None:
|
| 34 |
+
x = self.downsample(x)
|
| 35 |
+
|
| 36 |
+
return self.relu(x + y)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CNNEncoder(nn.Module):
|
| 40 |
+
def __init__(self, output_dim=128,
|
| 41 |
+
norm_layer=nn.InstanceNorm2d,
|
| 42 |
+
num_output_scales=1,
|
| 43 |
+
**kwargs,
|
| 44 |
+
):
|
| 45 |
+
super(CNNEncoder, self).__init__()
|
| 46 |
+
self.num_branch = num_output_scales
|
| 47 |
+
|
| 48 |
+
feature_dims = [64, 96, 128]
|
| 49 |
+
|
| 50 |
+
self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
|
| 51 |
+
self.norm1 = norm_layer(feature_dims[0])
|
| 52 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 53 |
+
|
| 54 |
+
self.in_planes = feature_dims[0]
|
| 55 |
+
self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
|
| 56 |
+
self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
|
| 57 |
+
|
| 58 |
+
# highest resolution 1/4 or 1/8
|
| 59 |
+
stride = 2 if num_output_scales == 1 else 1
|
| 60 |
+
self.layer3 = self._make_layer(feature_dims[2], stride=stride,
|
| 61 |
+
norm_layer=norm_layer,
|
| 62 |
+
) # 1/4 or 1/8
|
| 63 |
+
|
| 64 |
+
self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
|
| 65 |
+
|
| 66 |
+
if self.num_branch > 1:
|
| 67 |
+
if self.num_branch == 4:
|
| 68 |
+
strides = (1, 2, 4, 8)
|
| 69 |
+
elif self.num_branch == 3:
|
| 70 |
+
strides = (1, 2, 4)
|
| 71 |
+
elif self.num_branch == 2:
|
| 72 |
+
strides = (1, 2)
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError
|
| 75 |
+
|
| 76 |
+
self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
|
| 77 |
+
kernel_size=3,
|
| 78 |
+
strides=strides,
|
| 79 |
+
paddings=1,
|
| 80 |
+
num_branch=self.num_branch,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
for m in self.modules():
|
| 84 |
+
if isinstance(m, nn.Conv2d):
|
| 85 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 86 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
| 87 |
+
if m.weight is not None:
|
| 88 |
+
nn.init.constant_(m.weight, 1)
|
| 89 |
+
if m.bias is not None:
|
| 90 |
+
nn.init.constant_(m.bias, 0)
|
| 91 |
+
|
| 92 |
+
def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
|
| 93 |
+
layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
|
| 94 |
+
layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
|
| 95 |
+
|
| 96 |
+
layers = (layer1, layer2)
|
| 97 |
+
|
| 98 |
+
self.in_planes = dim
|
| 99 |
+
return nn.Sequential(*layers)
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
x = self.conv1(x)
|
| 103 |
+
x = self.norm1(x)
|
| 104 |
+
x = self.relu1(x)
|
| 105 |
+
|
| 106 |
+
x = self.layer1(x) # 1/2
|
| 107 |
+
x = self.layer2(x) # 1/4
|
| 108 |
+
x = self.layer3(x) # 1/8 or 1/4
|
| 109 |
+
|
| 110 |
+
x = self.conv2(x)
|
| 111 |
+
|
| 112 |
+
if self.num_branch > 1:
|
| 113 |
+
out = self.trident_conv([x] * self.num_branch) # high to low res
|
| 114 |
+
else:
|
| 115 |
+
out = [x]
|
| 116 |
+
|
| 117 |
+
return out
|
unimatch/geometry.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def coords_grid(b, h, w, homogeneous=False, device=None):
|
| 6 |
+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
| 7 |
+
|
| 8 |
+
stacks = [x, y]
|
| 9 |
+
|
| 10 |
+
if homogeneous:
|
| 11 |
+
ones = torch.ones_like(x) # [H, W]
|
| 12 |
+
stacks.append(ones)
|
| 13 |
+
|
| 14 |
+
grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
|
| 15 |
+
|
| 16 |
+
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
| 17 |
+
|
| 18 |
+
if device is not None:
|
| 19 |
+
grid = grid.to(device)
|
| 20 |
+
|
| 21 |
+
return grid
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
|
| 25 |
+
assert device is not None
|
| 26 |
+
|
| 27 |
+
x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
|
| 28 |
+
torch.linspace(h_min, h_max, len_h, device=device)],
|
| 29 |
+
)
|
| 30 |
+
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
|
| 31 |
+
|
| 32 |
+
return grid
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def normalize_coords(coords, h, w):
|
| 36 |
+
# coords: [B, H, W, 2]
|
| 37 |
+
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
|
| 38 |
+
return (coords - c) / c # [-1, 1]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
|
| 42 |
+
# img: [B, C, H, W]
|
| 43 |
+
# sample_coords: [B, 2, H, W] in image scale
|
| 44 |
+
if sample_coords.size(1) != 2: # [B, H, W, 2]
|
| 45 |
+
sample_coords = sample_coords.permute(0, 3, 1, 2)
|
| 46 |
+
|
| 47 |
+
b, _, h, w = sample_coords.shape
|
| 48 |
+
|
| 49 |
+
# Normalize to [-1, 1]
|
| 50 |
+
x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
|
| 51 |
+
y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
|
| 52 |
+
|
| 53 |
+
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
|
| 54 |
+
|
| 55 |
+
img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
|
| 56 |
+
|
| 57 |
+
if return_mask:
|
| 58 |
+
mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
|
| 59 |
+
|
| 60 |
+
return img, mask
|
| 61 |
+
|
| 62 |
+
return img
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
|
| 66 |
+
b, c, h, w = feature.size()
|
| 67 |
+
assert flow.size(1) == 2
|
| 68 |
+
|
| 69 |
+
grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
|
| 70 |
+
|
| 71 |
+
return bilinear_sample(feature, grid, padding_mode=padding_mode,
|
| 72 |
+
return_mask=mask)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def forward_backward_consistency_check(fwd_flow, bwd_flow,
|
| 76 |
+
alpha=0.01,
|
| 77 |
+
beta=0.5
|
| 78 |
+
):
|
| 79 |
+
# fwd_flow, bwd_flow: [B, 2, H, W]
|
| 80 |
+
# alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
|
| 81 |
+
assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
|
| 82 |
+
assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
|
| 83 |
+
flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
|
| 84 |
+
|
| 85 |
+
warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
|
| 86 |
+
warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
|
| 87 |
+
|
| 88 |
+
diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
|
| 89 |
+
diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
|
| 90 |
+
|
| 91 |
+
threshold = alpha * flow_mag + beta
|
| 92 |
+
|
| 93 |
+
fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
|
| 94 |
+
bwd_occ = (diff_bwd > threshold).float()
|
| 95 |
+
|
| 96 |
+
return fwd_occ, bwd_occ
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def back_project(depth, intrinsics):
|
| 100 |
+
# Back project 2D pixel coords to 3D points
|
| 101 |
+
# depth: [B, H, W]
|
| 102 |
+
# intrinsics: [B, 3, 3]
|
| 103 |
+
b, h, w = depth.shape
|
| 104 |
+
grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
|
| 105 |
+
|
| 106 |
+
intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3]
|
| 107 |
+
|
| 108 |
+
points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W]
|
| 109 |
+
|
| 110 |
+
return points
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
|
| 114 |
+
# Transform 3D points from reference camera to target camera
|
| 115 |
+
# points_ref: [B, 3, H, W]
|
| 116 |
+
# extrinsics_ref: [B, 4, 4]
|
| 117 |
+
# extrinsics_tgt: [B, 4, 4]
|
| 118 |
+
# extrinsics_rel: [B, 4, 4], relative pose transform
|
| 119 |
+
b, _, h, w = points_ref.shape
|
| 120 |
+
|
| 121 |
+
if extrinsics_rel is None:
|
| 122 |
+
extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4]
|
| 123 |
+
|
| 124 |
+
points_tgt = torch.bmm(extrinsics_rel[:, :3, :3],
|
| 125 |
+
points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W]
|
| 126 |
+
|
| 127 |
+
points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W]
|
| 128 |
+
|
| 129 |
+
return points_tgt
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def reproject(points_tgt, intrinsics, return_mask=False):
|
| 133 |
+
# reproject to target view
|
| 134 |
+
# points_tgt: [B, 3, H, W]
|
| 135 |
+
# intrinsics: [B, 3, 3]
|
| 136 |
+
|
| 137 |
+
b, _, h, w = points_tgt.shape
|
| 138 |
+
|
| 139 |
+
proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W]
|
| 140 |
+
|
| 141 |
+
X = proj_points[:, 0]
|
| 142 |
+
Y = proj_points[:, 1]
|
| 143 |
+
Z = proj_points[:, 2].clamp(min=1e-3)
|
| 144 |
+
|
| 145 |
+
pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale
|
| 146 |
+
|
| 147 |
+
if return_mask:
|
| 148 |
+
# valid mask in pixel space
|
| 149 |
+
mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & (
|
| 150 |
+
pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W]
|
| 151 |
+
|
| 152 |
+
return pixel_coords, mask
|
| 153 |
+
|
| 154 |
+
return pixel_coords
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
|
| 158 |
+
return_mask=False):
|
| 159 |
+
# Compute reprojection sample coords
|
| 160 |
+
points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W]
|
| 161 |
+
points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel)
|
| 162 |
+
|
| 163 |
+
if return_mask:
|
| 164 |
+
reproj_coords, mask = reproject(points_tgt, intrinsics,
|
| 165 |
+
return_mask=return_mask) # [B, 2, H, W] in image scale
|
| 166 |
+
|
| 167 |
+
return reproj_coords, mask
|
| 168 |
+
|
| 169 |
+
reproj_coords = reproject(points_tgt, intrinsics,
|
| 170 |
+
return_mask=return_mask) # [B, 2, H, W] in image scale
|
| 171 |
+
|
| 172 |
+
return reproj_coords
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def compute_flow_with_depth_pose(depth_ref, intrinsics,
|
| 176 |
+
extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
|
| 177 |
+
return_mask=False):
|
| 178 |
+
b, h, w = depth_ref.shape
|
| 179 |
+
coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W]
|
| 180 |
+
|
| 181 |
+
if return_mask:
|
| 182 |
+
reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
|
| 183 |
+
extrinsics_rel=extrinsics_rel,
|
| 184 |
+
return_mask=return_mask) # [B, 2, H, W]
|
| 185 |
+
rigid_flow = reproj_coords - coords_init
|
| 186 |
+
|
| 187 |
+
return rigid_flow, mask
|
| 188 |
+
|
| 189 |
+
reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
|
| 190 |
+
extrinsics_rel=extrinsics_rel,
|
| 191 |
+
return_mask=return_mask) # [B, 2, H, W]
|
| 192 |
+
|
| 193 |
+
rigid_flow = reproj_coords - coords_init
|
| 194 |
+
|
| 195 |
+
return rigid_flow
|
unimatch/matching.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from .geometry import coords_grid, generate_window_grid, normalize_coords
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def global_correlation_softmax(feature0, feature1,
|
| 8 |
+
pred_bidir_flow=False,
|
| 9 |
+
):
|
| 10 |
+
# global correlation
|
| 11 |
+
b, c, h, w = feature0.shape
|
| 12 |
+
feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
|
| 13 |
+
feature1 = feature1.view(b, c, -1) # [B, C, H*W]
|
| 14 |
+
|
| 15 |
+
correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
|
| 16 |
+
|
| 17 |
+
# flow from softmax
|
| 18 |
+
init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
|
| 19 |
+
grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
|
| 20 |
+
|
| 21 |
+
correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
|
| 22 |
+
|
| 23 |
+
if pred_bidir_flow:
|
| 24 |
+
correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
|
| 25 |
+
init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
|
| 26 |
+
grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
|
| 27 |
+
b = b * 2
|
| 28 |
+
|
| 29 |
+
prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
|
| 30 |
+
|
| 31 |
+
correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
|
| 32 |
+
|
| 33 |
+
# when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
|
| 34 |
+
flow = correspondence - init_grid
|
| 35 |
+
|
| 36 |
+
return flow, prob
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def local_correlation_softmax(feature0, feature1, local_radius,
|
| 40 |
+
padding_mode='zeros',
|
| 41 |
+
):
|
| 42 |
+
b, c, h, w = feature0.size()
|
| 43 |
+
coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
|
| 44 |
+
coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
|
| 45 |
+
|
| 46 |
+
local_h = 2 * local_radius + 1
|
| 47 |
+
local_w = 2 * local_radius + 1
|
| 48 |
+
|
| 49 |
+
window_grid = generate_window_grid(-local_radius, local_radius,
|
| 50 |
+
-local_radius, local_radius,
|
| 51 |
+
local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
|
| 52 |
+
window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
|
| 53 |
+
sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
|
| 54 |
+
|
| 55 |
+
sample_coords_softmax = sample_coords
|
| 56 |
+
|
| 57 |
+
# exclude coords that are out of image space
|
| 58 |
+
valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
|
| 59 |
+
valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
|
| 60 |
+
|
| 61 |
+
valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
|
| 62 |
+
|
| 63 |
+
# normalize coordinates to [-1, 1]
|
| 64 |
+
sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
|
| 65 |
+
window_feature = F.grid_sample(feature1, sample_coords_norm,
|
| 66 |
+
padding_mode=padding_mode, align_corners=True
|
| 67 |
+
).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
|
| 68 |
+
feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
|
| 69 |
+
|
| 70 |
+
corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
|
| 71 |
+
|
| 72 |
+
# mask invalid locations
|
| 73 |
+
corr[~valid] = -1e9
|
| 74 |
+
|
| 75 |
+
prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
|
| 76 |
+
|
| 77 |
+
correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
|
| 78 |
+
b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
|
| 79 |
+
|
| 80 |
+
flow = correspondence - coords_init
|
| 81 |
+
match_prob = prob
|
| 82 |
+
|
| 83 |
+
return flow, match_prob
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def local_correlation_with_flow(feature0, feature1,
|
| 87 |
+
flow,
|
| 88 |
+
local_radius,
|
| 89 |
+
padding_mode='zeros',
|
| 90 |
+
dilation=1,
|
| 91 |
+
):
|
| 92 |
+
b, c, h, w = feature0.size()
|
| 93 |
+
coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
|
| 94 |
+
coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
|
| 95 |
+
|
| 96 |
+
local_h = 2 * local_radius + 1
|
| 97 |
+
local_w = 2 * local_radius + 1
|
| 98 |
+
|
| 99 |
+
window_grid = generate_window_grid(-local_radius, local_radius,
|
| 100 |
+
-local_radius, local_radius,
|
| 101 |
+
local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
|
| 102 |
+
window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
|
| 103 |
+
sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2]
|
| 104 |
+
|
| 105 |
+
# flow can be zero when using features after transformer
|
| 106 |
+
if not isinstance(flow, float):
|
| 107 |
+
sample_coords = sample_coords + flow.view(
|
| 108 |
+
b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2]
|
| 109 |
+
else:
|
| 110 |
+
assert flow == 0.
|
| 111 |
+
|
| 112 |
+
# normalize coordinates to [-1, 1]
|
| 113 |
+
sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
|
| 114 |
+
window_feature = F.grid_sample(feature1, sample_coords_norm,
|
| 115 |
+
padding_mode=padding_mode, align_corners=True
|
| 116 |
+
).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
|
| 117 |
+
feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
|
| 118 |
+
|
| 119 |
+
corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
|
| 120 |
+
|
| 121 |
+
corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W]
|
| 122 |
+
|
| 123 |
+
return corr
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def global_correlation_softmax_stereo(feature0, feature1,
|
| 127 |
+
):
|
| 128 |
+
# global correlation on horizontal direction
|
| 129 |
+
b, c, h, w = feature0.shape
|
| 130 |
+
|
| 131 |
+
x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W]
|
| 132 |
+
|
| 133 |
+
feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C]
|
| 134 |
+
feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W]
|
| 135 |
+
|
| 136 |
+
correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W]
|
| 137 |
+
|
| 138 |
+
# mask subsequent positions to make disparity positive
|
| 139 |
+
mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W]
|
| 140 |
+
valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W]
|
| 141 |
+
|
| 142 |
+
correlation[~valid_mask] = -1e9
|
| 143 |
+
|
| 144 |
+
prob = F.softmax(correlation, dim=-1) # [B, H, W, W]
|
| 145 |
+
|
| 146 |
+
correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W]
|
| 147 |
+
|
| 148 |
+
# NOTE: unlike flow, disparity is typically positive
|
| 149 |
+
disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W]
|
| 150 |
+
|
| 151 |
+
return disparity.unsqueeze(1), prob # feature resolution
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def local_correlation_softmax_stereo(feature0, feature1, local_radius,
|
| 155 |
+
):
|
| 156 |
+
b, c, h, w = feature0.size()
|
| 157 |
+
coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
|
| 158 |
+
coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2]
|
| 159 |
+
|
| 160 |
+
local_h = 1
|
| 161 |
+
local_w = 2 * local_radius + 1
|
| 162 |
+
|
| 163 |
+
window_grid = generate_window_grid(0, 0,
|
| 164 |
+
-local_radius, local_radius,
|
| 165 |
+
local_h, local_w, device=feature0.device) # [1, 2R+1, 2]
|
| 166 |
+
window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2]
|
| 167 |
+
sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2]
|
| 168 |
+
|
| 169 |
+
sample_coords_softmax = sample_coords
|
| 170 |
+
|
| 171 |
+
# exclude coords that are out of image space
|
| 172 |
+
valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
|
| 173 |
+
valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
|
| 174 |
+
|
| 175 |
+
valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
|
| 176 |
+
|
| 177 |
+
# normalize coordinates to [-1, 1]
|
| 178 |
+
sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
|
| 179 |
+
window_feature = F.grid_sample(feature1, sample_coords_norm,
|
| 180 |
+
padding_mode='zeros', align_corners=True
|
| 181 |
+
).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)]
|
| 182 |
+
feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C]
|
| 183 |
+
|
| 184 |
+
corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)]
|
| 185 |
+
|
| 186 |
+
# mask invalid locations
|
| 187 |
+
corr[~valid] = -1e9
|
| 188 |
+
|
| 189 |
+
prob = F.softmax(corr, -1) # [B, H*W, (2R+1)]
|
| 190 |
+
|
| 191 |
+
correspondence = torch.matmul(prob.unsqueeze(-2),
|
| 192 |
+
sample_coords_softmax).squeeze(-2).view(
|
| 193 |
+
b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
|
| 194 |
+
|
| 195 |
+
flow = correspondence - coords_init # flow at feature resolution
|
| 196 |
+
match_prob = prob
|
| 197 |
+
|
| 198 |
+
flow_x = -flow[:, :1] # [B, 1, H, W]
|
| 199 |
+
|
| 200 |
+
return flow_x, match_prob
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def correlation_softmax_depth(feature0, feature1,
|
| 204 |
+
intrinsics,
|
| 205 |
+
pose,
|
| 206 |
+
depth_candidates,
|
| 207 |
+
depth_from_argmax=False,
|
| 208 |
+
pred_bidir_depth=False,
|
| 209 |
+
):
|
| 210 |
+
b, c, h, w = feature0.size()
|
| 211 |
+
assert depth_candidates.dim() == 4 # [B, D, H, W]
|
| 212 |
+
scale_factor = c ** 0.5
|
| 213 |
+
|
| 214 |
+
if pred_bidir_depth:
|
| 215 |
+
feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
|
| 216 |
+
intrinsics = intrinsics.repeat(2, 1, 1)
|
| 217 |
+
pose = torch.cat((pose, torch.inverse(pose)), dim=0)
|
| 218 |
+
depth_candidates = depth_candidates.repeat(2, 1, 1, 1)
|
| 219 |
+
|
| 220 |
+
# depth candidates are actually inverse depth
|
| 221 |
+
warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose,
|
| 222 |
+
1. / depth_candidates,
|
| 223 |
+
) # [B, C, D, H, W]
|
| 224 |
+
|
| 225 |
+
correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W]
|
| 226 |
+
|
| 227 |
+
match_prob = F.softmax(correlation, dim=1) # [B, D, H, W]
|
| 228 |
+
|
| 229 |
+
# for cross-task transfer (flow -> depth), extract depth with argmax at test time
|
| 230 |
+
if depth_from_argmax:
|
| 231 |
+
index = torch.argmax(match_prob, dim=1, keepdim=True)
|
| 232 |
+
depth = torch.gather(depth_candidates, dim=1, index=index)
|
| 233 |
+
else:
|
| 234 |
+
depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W]
|
| 235 |
+
|
| 236 |
+
return depth, match_prob
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth,
|
| 240 |
+
clamp_min_depth=1e-3,
|
| 241 |
+
):
|
| 242 |
+
"""
|
| 243 |
+
feature1: [B, C, H, W]
|
| 244 |
+
intrinsics: [B, 3, 3]
|
| 245 |
+
pose: [B, 4, 4]
|
| 246 |
+
depth: [B, D, H, W]
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
assert intrinsics.size(1) == intrinsics.size(2) == 3
|
| 250 |
+
assert pose.size(1) == pose.size(2) == 4
|
| 251 |
+
assert depth.dim() == 4
|
| 252 |
+
|
| 253 |
+
b, d, h, w = depth.size()
|
| 254 |
+
c = feature1.size(1)
|
| 255 |
+
|
| 256 |
+
with torch.no_grad():
|
| 257 |
+
# pixel coordinates
|
| 258 |
+
grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
|
| 259 |
+
# back project to 3D and transform viewpoint
|
| 260 |
+
points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W]
|
| 261 |
+
points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(
|
| 262 |
+
1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W]
|
| 263 |
+
points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W]
|
| 264 |
+
# reproject to 2D image plane
|
| 265 |
+
points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W]
|
| 266 |
+
pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W]
|
| 267 |
+
|
| 268 |
+
# normalize to [-1, 1]
|
| 269 |
+
x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
|
| 270 |
+
y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
|
| 271 |
+
|
| 272 |
+
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2]
|
| 273 |
+
|
| 274 |
+
# sample features
|
| 275 |
+
warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear',
|
| 276 |
+
padding_mode='zeros',
|
| 277 |
+
align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W]
|
| 278 |
+
|
| 279 |
+
return warped_feature
|
unimatch/position.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PositionEmbeddingSine(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 12 |
+
used by the Attention is all you need paper, generalized to work on images.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.num_pos_feats = num_pos_feats
|
| 18 |
+
self.temperature = temperature
|
| 19 |
+
self.normalize = normalize
|
| 20 |
+
if scale is not None and normalize is False:
|
| 21 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 22 |
+
if scale is None:
|
| 23 |
+
scale = 2 * math.pi
|
| 24 |
+
self.scale = scale
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
# x = tensor_list.tensors # [B, C, H, W]
|
| 28 |
+
# mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
|
| 29 |
+
b, c, h, w = x.size()
|
| 30 |
+
mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
|
| 31 |
+
y_embed = mask.cumsum(1, dtype=torch.float32)
|
| 32 |
+
x_embed = mask.cumsum(2, dtype=torch.float32)
|
| 33 |
+
if self.normalize:
|
| 34 |
+
eps = 1e-6
|
| 35 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 36 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 37 |
+
|
| 38 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 39 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 40 |
+
|
| 41 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 42 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 43 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 44 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 45 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 46 |
+
return pos
|
unimatch/reg_refine.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FlowHead(nn.Module):
|
| 7 |
+
def __init__(self, input_dim=128, hidden_dim=256,
|
| 8 |
+
out_dim=2,
|
| 9 |
+
):
|
| 10 |
+
super(FlowHead, self).__init__()
|
| 11 |
+
|
| 12 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
| 13 |
+
self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1)
|
| 14 |
+
self.relu = nn.ReLU(inplace=True)
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
out = self.conv2(self.relu(self.conv1(x)))
|
| 18 |
+
|
| 19 |
+
return out
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SepConvGRU(nn.Module):
|
| 23 |
+
def __init__(self, hidden_dim=128, input_dim=192 + 128,
|
| 24 |
+
kernel_size=5,
|
| 25 |
+
):
|
| 26 |
+
padding = (kernel_size - 1) // 2
|
| 27 |
+
|
| 28 |
+
super(SepConvGRU, self).__init__()
|
| 29 |
+
self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
|
| 30 |
+
self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
|
| 31 |
+
self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
|
| 32 |
+
|
| 33 |
+
self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
|
| 34 |
+
self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
|
| 35 |
+
self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
|
| 36 |
+
|
| 37 |
+
def forward(self, h, x):
|
| 38 |
+
# horizontal
|
| 39 |
+
hx = torch.cat([h, x], dim=1)
|
| 40 |
+
z = torch.sigmoid(self.convz1(hx))
|
| 41 |
+
r = torch.sigmoid(self.convr1(hx))
|
| 42 |
+
q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
|
| 43 |
+
h = (1 - z) * h + z * q
|
| 44 |
+
|
| 45 |
+
# vertical
|
| 46 |
+
hx = torch.cat([h, x], dim=1)
|
| 47 |
+
z = torch.sigmoid(self.convz2(hx))
|
| 48 |
+
r = torch.sigmoid(self.convr2(hx))
|
| 49 |
+
q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
|
| 50 |
+
h = (1 - z) * h + z * q
|
| 51 |
+
|
| 52 |
+
return h
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BasicMotionEncoder(nn.Module):
|
| 56 |
+
def __init__(self, corr_channels=324,
|
| 57 |
+
flow_channels=2,
|
| 58 |
+
):
|
| 59 |
+
super(BasicMotionEncoder, self).__init__()
|
| 60 |
+
|
| 61 |
+
self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
|
| 62 |
+
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
| 63 |
+
self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3)
|
| 64 |
+
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
| 65 |
+
self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1)
|
| 66 |
+
|
| 67 |
+
def forward(self, flow, corr):
|
| 68 |
+
cor = F.relu(self.convc1(corr))
|
| 69 |
+
cor = F.relu(self.convc2(cor))
|
| 70 |
+
flo = F.relu(self.convf1(flow))
|
| 71 |
+
flo = F.relu(self.convf2(flo))
|
| 72 |
+
|
| 73 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
| 74 |
+
out = F.relu(self.conv(cor_flo))
|
| 75 |
+
return torch.cat([out, flow], dim=1)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class BasicUpdateBlock(nn.Module):
|
| 79 |
+
def __init__(self, corr_channels=324,
|
| 80 |
+
hidden_dim=128,
|
| 81 |
+
context_dim=128,
|
| 82 |
+
downsample_factor=8,
|
| 83 |
+
flow_dim=2,
|
| 84 |
+
bilinear_up=False,
|
| 85 |
+
):
|
| 86 |
+
super(BasicUpdateBlock, self).__init__()
|
| 87 |
+
|
| 88 |
+
self.encoder = BasicMotionEncoder(corr_channels=corr_channels,
|
| 89 |
+
flow_channels=flow_dim,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
|
| 93 |
+
|
| 94 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
|
| 95 |
+
out_dim=flow_dim,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
if bilinear_up:
|
| 99 |
+
self.mask = None
|
| 100 |
+
else:
|
| 101 |
+
self.mask = nn.Sequential(
|
| 102 |
+
nn.Conv2d(hidden_dim, 256, 3, padding=1),
|
| 103 |
+
nn.ReLU(inplace=True),
|
| 104 |
+
nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
|
| 105 |
+
|
| 106 |
+
def forward(self, net, inp, corr, flow):
|
| 107 |
+
motion_features = self.encoder(flow, corr)
|
| 108 |
+
|
| 109 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
| 110 |
+
|
| 111 |
+
net = self.gru(net, inp)
|
| 112 |
+
delta_flow = self.flow_head(net)
|
| 113 |
+
|
| 114 |
+
if self.mask is not None:
|
| 115 |
+
mask = self.mask(net)
|
| 116 |
+
else:
|
| 117 |
+
mask = None
|
| 118 |
+
|
| 119 |
+
return net, mask, delta_flow
|
unimatch/transformer.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .attention import (single_head_full_attention, single_head_split_window_attention,
|
| 5 |
+
single_head_full_attention_1d, single_head_split_window_attention_1d)
|
| 6 |
+
from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TransformerLayer(nn.Module):
|
| 10 |
+
def __init__(self,
|
| 11 |
+
d_model=128,
|
| 12 |
+
nhead=1,
|
| 13 |
+
no_ffn=False,
|
| 14 |
+
ffn_dim_expansion=4,
|
| 15 |
+
):
|
| 16 |
+
super(TransformerLayer, self).__init__()
|
| 17 |
+
|
| 18 |
+
self.dim = d_model
|
| 19 |
+
self.nhead = nhead
|
| 20 |
+
self.no_ffn = no_ffn
|
| 21 |
+
|
| 22 |
+
# multi-head attention
|
| 23 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 24 |
+
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
| 25 |
+
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
| 26 |
+
|
| 27 |
+
self.merge = nn.Linear(d_model, d_model, bias=False)
|
| 28 |
+
|
| 29 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 30 |
+
|
| 31 |
+
# no ffn after self-attn, with ffn after cross-attn
|
| 32 |
+
if not self.no_ffn:
|
| 33 |
+
in_channels = d_model * 2
|
| 34 |
+
self.mlp = nn.Sequential(
|
| 35 |
+
nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
|
| 36 |
+
nn.GELU(),
|
| 37 |
+
nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 41 |
+
|
| 42 |
+
def forward(self, source, target,
|
| 43 |
+
height=None,
|
| 44 |
+
width=None,
|
| 45 |
+
shifted_window_attn_mask=None,
|
| 46 |
+
shifted_window_attn_mask_1d=None,
|
| 47 |
+
attn_type='swin',
|
| 48 |
+
with_shift=False,
|
| 49 |
+
attn_num_splits=None,
|
| 50 |
+
):
|
| 51 |
+
# source, target: [B, L, C]
|
| 52 |
+
query, key, value = source, target, target
|
| 53 |
+
|
| 54 |
+
# for stereo: 2d attn in self-attn, 1d attn in cross-attn
|
| 55 |
+
is_self_attn = (query - key).abs().max() < 1e-6
|
| 56 |
+
|
| 57 |
+
# single-head attention
|
| 58 |
+
query = self.q_proj(query) # [B, L, C]
|
| 59 |
+
key = self.k_proj(key) # [B, L, C]
|
| 60 |
+
value = self.v_proj(value) # [B, L, C]
|
| 61 |
+
|
| 62 |
+
if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d
|
| 63 |
+
if self.nhead > 1:
|
| 64 |
+
# we observe that multihead attention slows down the speed and increases the memory consumption
|
| 65 |
+
# without bringing obvious performance gains and thus the implementation is removed
|
| 66 |
+
raise NotImplementedError
|
| 67 |
+
else:
|
| 68 |
+
message = single_head_split_window_attention(query, key, value,
|
| 69 |
+
num_splits=attn_num_splits,
|
| 70 |
+
with_shift=with_shift,
|
| 71 |
+
h=height,
|
| 72 |
+
w=width,
|
| 73 |
+
attn_mask=shifted_window_attn_mask,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d
|
| 77 |
+
if self.nhead > 1:
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
else:
|
| 80 |
+
if is_self_attn:
|
| 81 |
+
if attn_num_splits > 1:
|
| 82 |
+
message = single_head_split_window_attention(query, key, value,
|
| 83 |
+
num_splits=attn_num_splits,
|
| 84 |
+
with_shift=with_shift,
|
| 85 |
+
h=height,
|
| 86 |
+
w=width,
|
| 87 |
+
attn_mask=shifted_window_attn_mask,
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
# full 2d attn
|
| 91 |
+
message = single_head_full_attention(query, key, value) # [N, L, C]
|
| 92 |
+
|
| 93 |
+
else:
|
| 94 |
+
# cross attn 1d
|
| 95 |
+
message = single_head_full_attention_1d(query, key, value,
|
| 96 |
+
h=height,
|
| 97 |
+
w=width,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d
|
| 101 |
+
if self.nhead > 1:
|
| 102 |
+
raise NotImplementedError
|
| 103 |
+
else:
|
| 104 |
+
if is_self_attn:
|
| 105 |
+
if attn_num_splits > 1:
|
| 106 |
+
# self attn shift window
|
| 107 |
+
message = single_head_split_window_attention(query, key, value,
|
| 108 |
+
num_splits=attn_num_splits,
|
| 109 |
+
with_shift=with_shift,
|
| 110 |
+
h=height,
|
| 111 |
+
w=width,
|
| 112 |
+
attn_mask=shifted_window_attn_mask,
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
# full 2d attn
|
| 116 |
+
message = single_head_full_attention(query, key, value) # [N, L, C]
|
| 117 |
+
else:
|
| 118 |
+
if attn_num_splits > 1:
|
| 119 |
+
assert shifted_window_attn_mask_1d is not None
|
| 120 |
+
# cross attn 1d shift
|
| 121 |
+
message = single_head_split_window_attention_1d(query, key, value,
|
| 122 |
+
num_splits=attn_num_splits,
|
| 123 |
+
with_shift=with_shift,
|
| 124 |
+
h=height,
|
| 125 |
+
w=width,
|
| 126 |
+
attn_mask=shifted_window_attn_mask_1d,
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
message = single_head_full_attention_1d(query, key, value,
|
| 130 |
+
h=height,
|
| 131 |
+
w=width,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
else:
|
| 135 |
+
message = single_head_full_attention(query, key, value) # [B, L, C]
|
| 136 |
+
|
| 137 |
+
message = self.merge(message) # [B, L, C]
|
| 138 |
+
message = self.norm1(message)
|
| 139 |
+
|
| 140 |
+
if not self.no_ffn:
|
| 141 |
+
message = self.mlp(torch.cat([source, message], dim=-1))
|
| 142 |
+
message = self.norm2(message)
|
| 143 |
+
|
| 144 |
+
return source + message
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TransformerBlock(nn.Module):
|
| 148 |
+
"""self attention + cross attention + FFN"""
|
| 149 |
+
|
| 150 |
+
def __init__(self,
|
| 151 |
+
d_model=128,
|
| 152 |
+
nhead=1,
|
| 153 |
+
ffn_dim_expansion=4,
|
| 154 |
+
):
|
| 155 |
+
super(TransformerBlock, self).__init__()
|
| 156 |
+
|
| 157 |
+
self.self_attn = TransformerLayer(d_model=d_model,
|
| 158 |
+
nhead=nhead,
|
| 159 |
+
no_ffn=True,
|
| 160 |
+
ffn_dim_expansion=ffn_dim_expansion,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
self.cross_attn_ffn = TransformerLayer(d_model=d_model,
|
| 164 |
+
nhead=nhead,
|
| 165 |
+
ffn_dim_expansion=ffn_dim_expansion,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def forward(self, source, target,
|
| 169 |
+
height=None,
|
| 170 |
+
width=None,
|
| 171 |
+
shifted_window_attn_mask=None,
|
| 172 |
+
shifted_window_attn_mask_1d=None,
|
| 173 |
+
attn_type='swin',
|
| 174 |
+
with_shift=False,
|
| 175 |
+
attn_num_splits=None,
|
| 176 |
+
):
|
| 177 |
+
# source, target: [B, L, C]
|
| 178 |
+
|
| 179 |
+
# self attention
|
| 180 |
+
source = self.self_attn(source, source,
|
| 181 |
+
height=height,
|
| 182 |
+
width=width,
|
| 183 |
+
shifted_window_attn_mask=shifted_window_attn_mask,
|
| 184 |
+
attn_type=attn_type,
|
| 185 |
+
with_shift=with_shift,
|
| 186 |
+
attn_num_splits=attn_num_splits,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# cross attention and ffn
|
| 190 |
+
source = self.cross_attn_ffn(source, target,
|
| 191 |
+
height=height,
|
| 192 |
+
width=width,
|
| 193 |
+
shifted_window_attn_mask=shifted_window_attn_mask,
|
| 194 |
+
shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
|
| 195 |
+
attn_type=attn_type,
|
| 196 |
+
with_shift=with_shift,
|
| 197 |
+
attn_num_splits=attn_num_splits,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return source
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class FeatureTransformer(nn.Module):
|
| 204 |
+
def __init__(self,
|
| 205 |
+
num_layers=6,
|
| 206 |
+
d_model=128,
|
| 207 |
+
nhead=1,
|
| 208 |
+
ffn_dim_expansion=4,
|
| 209 |
+
):
|
| 210 |
+
super(FeatureTransformer, self).__init__()
|
| 211 |
+
|
| 212 |
+
self.d_model = d_model
|
| 213 |
+
self.nhead = nhead
|
| 214 |
+
|
| 215 |
+
self.layers = nn.ModuleList([
|
| 216 |
+
TransformerBlock(d_model=d_model,
|
| 217 |
+
nhead=nhead,
|
| 218 |
+
ffn_dim_expansion=ffn_dim_expansion,
|
| 219 |
+
)
|
| 220 |
+
for i in range(num_layers)])
|
| 221 |
+
|
| 222 |
+
for p in self.parameters():
|
| 223 |
+
if p.dim() > 1:
|
| 224 |
+
nn.init.xavier_uniform_(p)
|
| 225 |
+
|
| 226 |
+
def forward(self, feature0, feature1,
|
| 227 |
+
attn_type='swin',
|
| 228 |
+
attn_num_splits=None,
|
| 229 |
+
**kwargs,
|
| 230 |
+
):
|
| 231 |
+
|
| 232 |
+
b, c, h, w = feature0.shape
|
| 233 |
+
assert self.d_model == c
|
| 234 |
+
|
| 235 |
+
feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
|
| 236 |
+
feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
|
| 237 |
+
|
| 238 |
+
# 2d attention
|
| 239 |
+
if 'swin' in attn_type and attn_num_splits > 1:
|
| 240 |
+
# global and refine use different number of splits
|
| 241 |
+
window_size_h = h // attn_num_splits
|
| 242 |
+
window_size_w = w // attn_num_splits
|
| 243 |
+
|
| 244 |
+
# compute attn mask once
|
| 245 |
+
shifted_window_attn_mask = generate_shift_window_attn_mask(
|
| 246 |
+
input_resolution=(h, w),
|
| 247 |
+
window_size_h=window_size_h,
|
| 248 |
+
window_size_w=window_size_w,
|
| 249 |
+
shift_size_h=window_size_h // 2,
|
| 250 |
+
shift_size_w=window_size_w // 2,
|
| 251 |
+
device=feature0.device,
|
| 252 |
+
) # [K*K, H/K*W/K, H/K*W/K]
|
| 253 |
+
else:
|
| 254 |
+
shifted_window_attn_mask = None
|
| 255 |
+
|
| 256 |
+
# 1d attention
|
| 257 |
+
if 'swin1d' in attn_type and attn_num_splits > 1:
|
| 258 |
+
window_size_w = w // attn_num_splits
|
| 259 |
+
|
| 260 |
+
# compute attn mask once
|
| 261 |
+
shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
|
| 262 |
+
input_w=w,
|
| 263 |
+
window_size_w=window_size_w,
|
| 264 |
+
shift_size_w=window_size_w // 2,
|
| 265 |
+
device=feature0.device,
|
| 266 |
+
) # [K, W/K, W/K]
|
| 267 |
+
else:
|
| 268 |
+
shifted_window_attn_mask_1d = None
|
| 269 |
+
|
| 270 |
+
# concat feature0 and feature1 in batch dimension to compute in parallel
|
| 271 |
+
concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
|
| 272 |
+
concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
|
| 273 |
+
|
| 274 |
+
for i, layer in enumerate(self.layers):
|
| 275 |
+
concat0 = layer(concat0, concat1,
|
| 276 |
+
height=h,
|
| 277 |
+
width=w,
|
| 278 |
+
attn_type=attn_type,
|
| 279 |
+
with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1,
|
| 280 |
+
attn_num_splits=attn_num_splits,
|
| 281 |
+
shifted_window_attn_mask=shifted_window_attn_mask,
|
| 282 |
+
shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# update feature1
|
| 286 |
+
concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
|
| 287 |
+
|
| 288 |
+
feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
|
| 289 |
+
|
| 290 |
+
# reshape back
|
| 291 |
+
feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
|
| 292 |
+
feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
|
| 293 |
+
|
| 294 |
+
return feature0, feature1
|
unimatch/trident_conv.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from torch.nn.modules.utils import _pair
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MultiScaleTridentConv(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
in_channels,
|
| 14 |
+
out_channels,
|
| 15 |
+
kernel_size,
|
| 16 |
+
stride=1,
|
| 17 |
+
strides=1,
|
| 18 |
+
paddings=0,
|
| 19 |
+
dilations=1,
|
| 20 |
+
dilation=1,
|
| 21 |
+
groups=1,
|
| 22 |
+
num_branch=1,
|
| 23 |
+
test_branch_idx=-1,
|
| 24 |
+
bias=False,
|
| 25 |
+
norm=None,
|
| 26 |
+
activation=None,
|
| 27 |
+
):
|
| 28 |
+
super(MultiScaleTridentConv, self).__init__()
|
| 29 |
+
self.in_channels = in_channels
|
| 30 |
+
self.out_channels = out_channels
|
| 31 |
+
self.kernel_size = _pair(kernel_size)
|
| 32 |
+
self.num_branch = num_branch
|
| 33 |
+
self.stride = _pair(stride)
|
| 34 |
+
self.groups = groups
|
| 35 |
+
self.with_bias = bias
|
| 36 |
+
self.dilation = dilation
|
| 37 |
+
if isinstance(paddings, int):
|
| 38 |
+
paddings = [paddings] * self.num_branch
|
| 39 |
+
if isinstance(dilations, int):
|
| 40 |
+
dilations = [dilations] * self.num_branch
|
| 41 |
+
if isinstance(strides, int):
|
| 42 |
+
strides = [strides] * self.num_branch
|
| 43 |
+
self.paddings = [_pair(padding) for padding in paddings]
|
| 44 |
+
self.dilations = [_pair(dilation) for dilation in dilations]
|
| 45 |
+
self.strides = [_pair(stride) for stride in strides]
|
| 46 |
+
self.test_branch_idx = test_branch_idx
|
| 47 |
+
self.norm = norm
|
| 48 |
+
self.activation = activation
|
| 49 |
+
|
| 50 |
+
assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
|
| 51 |
+
|
| 52 |
+
self.weight = nn.Parameter(
|
| 53 |
+
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
|
| 54 |
+
)
|
| 55 |
+
if bias:
|
| 56 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
| 57 |
+
else:
|
| 58 |
+
self.bias = None
|
| 59 |
+
|
| 60 |
+
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
|
| 61 |
+
if self.bias is not None:
|
| 62 |
+
nn.init.constant_(self.bias, 0)
|
| 63 |
+
|
| 64 |
+
def forward(self, inputs):
|
| 65 |
+
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
|
| 66 |
+
assert len(inputs) == num_branch
|
| 67 |
+
|
| 68 |
+
if self.training or self.test_branch_idx == -1:
|
| 69 |
+
outputs = [
|
| 70 |
+
F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
|
| 71 |
+
for input, stride, padding in zip(inputs, self.strides, self.paddings)
|
| 72 |
+
]
|
| 73 |
+
else:
|
| 74 |
+
outputs = [
|
| 75 |
+
F.conv2d(
|
| 76 |
+
inputs[0],
|
| 77 |
+
self.weight,
|
| 78 |
+
self.bias,
|
| 79 |
+
self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
|
| 80 |
+
self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
|
| 81 |
+
self.dilation,
|
| 82 |
+
self.groups,
|
| 83 |
+
)
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
if self.norm is not None:
|
| 87 |
+
outputs = [self.norm(x) for x in outputs]
|
| 88 |
+
if self.activation is not None:
|
| 89 |
+
outputs = [self.activation(x) for x in outputs]
|
| 90 |
+
return outputs
|
unimatch/unimatch.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from .backbone import CNNEncoder
|
| 6 |
+
from .transformer import FeatureTransformer
|
| 7 |
+
from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow,
|
| 8 |
+
global_correlation_softmax_stereo, local_correlation_softmax_stereo,
|
| 9 |
+
correlation_softmax_depth)
|
| 10 |
+
from .attention import SelfAttnPropagation
|
| 11 |
+
from .geometry import flow_warp, compute_flow_with_depth_pose
|
| 12 |
+
from .reg_refine import BasicUpdateBlock
|
| 13 |
+
from .utils import normalize_img, feature_add_position, upsample_flow_with_mask
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class UniMatch(nn.Module):
|
| 17 |
+
def __init__(self,
|
| 18 |
+
num_scales=1,
|
| 19 |
+
feature_channels=128,
|
| 20 |
+
upsample_factor=8,
|
| 21 |
+
num_head=1,
|
| 22 |
+
ffn_dim_expansion=4,
|
| 23 |
+
num_transformer_layers=6,
|
| 24 |
+
reg_refine=False, # optional local regression refinement
|
| 25 |
+
task='flow',
|
| 26 |
+
):
|
| 27 |
+
super(UniMatch, self).__init__()
|
| 28 |
+
|
| 29 |
+
self.feature_channels = feature_channels
|
| 30 |
+
self.num_scales = num_scales
|
| 31 |
+
self.upsample_factor = upsample_factor
|
| 32 |
+
self.reg_refine = reg_refine
|
| 33 |
+
|
| 34 |
+
# CNN
|
| 35 |
+
self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
|
| 36 |
+
|
| 37 |
+
# Transformer
|
| 38 |
+
self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
|
| 39 |
+
d_model=feature_channels,
|
| 40 |
+
nhead=num_head,
|
| 41 |
+
ffn_dim_expansion=ffn_dim_expansion,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# propagation with self-attn
|
| 45 |
+
self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)
|
| 46 |
+
|
| 47 |
+
if not self.reg_refine or task == 'depth':
|
| 48 |
+
# convex upsampling simiar to RAFT
|
| 49 |
+
# concat feature0 and low res flow as input
|
| 50 |
+
self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
|
| 51 |
+
nn.ReLU(inplace=True),
|
| 52 |
+
nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
|
| 53 |
+
# thus far, all the learnable parameters are task-agnostic
|
| 54 |
+
|
| 55 |
+
if reg_refine:
|
| 56 |
+
# optional task-specific local regression refinement
|
| 57 |
+
self.refine_proj = nn.Conv2d(128, 256, 1)
|
| 58 |
+
self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2,
|
| 59 |
+
downsample_factor=upsample_factor,
|
| 60 |
+
flow_dim=2 if task == 'flow' else 1,
|
| 61 |
+
bilinear_up=task == 'depth',
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def extract_feature(self, img0, img1):
|
| 65 |
+
concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
|
| 66 |
+
features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
|
| 67 |
+
|
| 68 |
+
# reverse: resolution from low to high
|
| 69 |
+
features = features[::-1]
|
| 70 |
+
|
| 71 |
+
feature0, feature1 = [], []
|
| 72 |
+
|
| 73 |
+
for i in range(len(features)):
|
| 74 |
+
feature = features[i]
|
| 75 |
+
chunks = torch.chunk(feature, 2, 0) # tuple
|
| 76 |
+
feature0.append(chunks[0])
|
| 77 |
+
feature1.append(chunks[1])
|
| 78 |
+
|
| 79 |
+
return feature0, feature1
|
| 80 |
+
|
| 81 |
+
def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
|
| 82 |
+
is_depth=False):
|
| 83 |
+
if bilinear:
|
| 84 |
+
multiplier = 1 if is_depth else upsample_factor
|
| 85 |
+
up_flow = F.interpolate(flow, scale_factor=upsample_factor,
|
| 86 |
+
mode='bilinear', align_corners=True) * multiplier
|
| 87 |
+
else:
|
| 88 |
+
concat = torch.cat((flow, feature), dim=1)
|
| 89 |
+
mask = self.upsampler(concat)
|
| 90 |
+
up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
|
| 91 |
+
is_depth=is_depth)
|
| 92 |
+
|
| 93 |
+
return up_flow
|
| 94 |
+
|
| 95 |
+
def forward(self, img0, img1,
|
| 96 |
+
attn_type=None,
|
| 97 |
+
attn_splits_list=None,
|
| 98 |
+
corr_radius_list=None,
|
| 99 |
+
prop_radius_list=None,
|
| 100 |
+
num_reg_refine=1,
|
| 101 |
+
pred_bidir_flow=False,
|
| 102 |
+
task='flow',
|
| 103 |
+
intrinsics=None,
|
| 104 |
+
pose=None, # relative pose transform
|
| 105 |
+
min_depth=1. / 0.5, # inverse depth range
|
| 106 |
+
max_depth=1. / 10,
|
| 107 |
+
num_depth_candidates=64,
|
| 108 |
+
depth_from_argmax=False,
|
| 109 |
+
pred_bidir_depth=False,
|
| 110 |
+
**kwargs,
|
| 111 |
+
):
|
| 112 |
+
|
| 113 |
+
if pred_bidir_flow:
|
| 114 |
+
assert task == 'flow'
|
| 115 |
+
|
| 116 |
+
if task == 'depth':
|
| 117 |
+
assert self.num_scales == 1 # multi-scale depth model is not supported yet
|
| 118 |
+
|
| 119 |
+
results_dict = {}
|
| 120 |
+
flow_preds = []
|
| 121 |
+
|
| 122 |
+
if task == 'flow':
|
| 123 |
+
# stereo and depth tasks have normalized img in dataloader
|
| 124 |
+
img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
|
| 125 |
+
|
| 126 |
+
# list of features, resolution low to high
|
| 127 |
+
feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
|
| 128 |
+
|
| 129 |
+
flow = None
|
| 130 |
+
|
| 131 |
+
if task != 'depth':
|
| 132 |
+
assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
|
| 133 |
+
else:
|
| 134 |
+
assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1
|
| 135 |
+
|
| 136 |
+
for scale_idx in range(self.num_scales):
|
| 137 |
+
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
|
| 138 |
+
|
| 139 |
+
if pred_bidir_flow and scale_idx > 0:
|
| 140 |
+
# predicting bidirectional flow with refinement
|
| 141 |
+
feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
|
| 142 |
+
|
| 143 |
+
feature0_ori, feature1_ori = feature0, feature1
|
| 144 |
+
|
| 145 |
+
upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
|
| 146 |
+
|
| 147 |
+
if task == 'depth':
|
| 148 |
+
# scale intrinsics
|
| 149 |
+
intrinsics_curr = intrinsics.clone()
|
| 150 |
+
intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor
|
| 151 |
+
|
| 152 |
+
if scale_idx > 0:
|
| 153 |
+
assert task != 'depth' # not supported for multi-scale depth model
|
| 154 |
+
flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
|
| 155 |
+
|
| 156 |
+
if flow is not None:
|
| 157 |
+
assert task != 'depth'
|
| 158 |
+
flow = flow.detach()
|
| 159 |
+
|
| 160 |
+
if task == 'stereo':
|
| 161 |
+
# construct flow vector for disparity
|
| 162 |
+
# flow here is actually disparity
|
| 163 |
+
zeros = torch.zeros_like(flow) # [B, 1, H, W]
|
| 164 |
+
# NOTE: reverse disp, disparity is positive
|
| 165 |
+
displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
|
| 166 |
+
feature1 = flow_warp(feature1, displace) # [B, C, H, W]
|
| 167 |
+
elif task == 'flow':
|
| 168 |
+
feature1 = flow_warp(feature1, flow) # [B, C, H, W]
|
| 169 |
+
else:
|
| 170 |
+
raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
attn_splits = attn_splits_list[scale_idx]
|
| 173 |
+
if task != 'depth':
|
| 174 |
+
corr_radius = corr_radius_list[scale_idx]
|
| 175 |
+
prop_radius = prop_radius_list[scale_idx]
|
| 176 |
+
|
| 177 |
+
# add position to features
|
| 178 |
+
feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
|
| 179 |
+
|
| 180 |
+
# Transformer
|
| 181 |
+
feature0, feature1 = self.transformer(feature0, feature1,
|
| 182 |
+
attn_type=attn_type,
|
| 183 |
+
attn_num_splits=attn_splits,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# correlation and softmax
|
| 187 |
+
if task == 'depth':
|
| 188 |
+
# first generate depth candidates
|
| 189 |
+
b, _, h, w = feature0.size()
|
| 190 |
+
depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0)
|
| 191 |
+
depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h,
|
| 192 |
+
w) # [B, D, H, W]
|
| 193 |
+
|
| 194 |
+
flow_pred = correlation_softmax_depth(feature0, feature1,
|
| 195 |
+
intrinsics_curr,
|
| 196 |
+
pose,
|
| 197 |
+
depth_candidates=depth_candidates,
|
| 198 |
+
depth_from_argmax=depth_from_argmax,
|
| 199 |
+
pred_bidir_depth=pred_bidir_depth,
|
| 200 |
+
)[0]
|
| 201 |
+
|
| 202 |
+
else:
|
| 203 |
+
if corr_radius == -1: # global matching
|
| 204 |
+
if task == 'flow':
|
| 205 |
+
flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
|
| 206 |
+
elif task == 'stereo':
|
| 207 |
+
flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0]
|
| 208 |
+
else:
|
| 209 |
+
raise NotImplementedError
|
| 210 |
+
else: # local matching
|
| 211 |
+
if task == 'flow':
|
| 212 |
+
flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
|
| 213 |
+
elif task == 'stereo':
|
| 214 |
+
flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0]
|
| 215 |
+
else:
|
| 216 |
+
raise NotImplementedError
|
| 217 |
+
|
| 218 |
+
# flow or residual flow
|
| 219 |
+
flow = flow + flow_pred if flow is not None else flow_pred
|
| 220 |
+
|
| 221 |
+
if task == 'stereo':
|
| 222 |
+
flow = flow.clamp(min=0) # positive disparity
|
| 223 |
+
|
| 224 |
+
# upsample to the original resolution for supervison at training time only
|
| 225 |
+
if self.training:
|
| 226 |
+
flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor,
|
| 227 |
+
is_depth=task == 'depth')
|
| 228 |
+
flow_preds.append(flow_bilinear)
|
| 229 |
+
|
| 230 |
+
# flow propagation with self-attn
|
| 231 |
+
if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0:
|
| 232 |
+
feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
|
| 233 |
+
|
| 234 |
+
flow = self.feature_flow_attn(feature0, flow.detach(),
|
| 235 |
+
local_window_attn=prop_radius > 0,
|
| 236 |
+
local_window_radius=prop_radius,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# bilinear exclude the last one
|
| 240 |
+
if self.training and scale_idx < self.num_scales - 1:
|
| 241 |
+
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
|
| 242 |
+
upsample_factor=upsample_factor,
|
| 243 |
+
is_depth=task == 'depth')
|
| 244 |
+
flow_preds.append(flow_up)
|
| 245 |
+
|
| 246 |
+
if scale_idx == self.num_scales - 1:
|
| 247 |
+
if not self.reg_refine:
|
| 248 |
+
# upsample to the original image resolution
|
| 249 |
+
|
| 250 |
+
if task == 'stereo':
|
| 251 |
+
flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
|
| 252 |
+
flow_up_pad = self.upsample_flow(flow_pad, feature0)
|
| 253 |
+
flow_up = -flow_up_pad[:, :1] # [B, 1, H, W]
|
| 254 |
+
elif task == 'depth':
|
| 255 |
+
depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
|
| 256 |
+
depth_up_pad = self.upsample_flow(depth_pad, feature0,
|
| 257 |
+
is_depth=True).clamp(min=min_depth, max=max_depth)
|
| 258 |
+
flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
|
| 259 |
+
else:
|
| 260 |
+
flow_up = self.upsample_flow(flow, feature0)
|
| 261 |
+
|
| 262 |
+
flow_preds.append(flow_up)
|
| 263 |
+
else:
|
| 264 |
+
# task-specific local regression refinement
|
| 265 |
+
# supervise current flow
|
| 266 |
+
if self.training:
|
| 267 |
+
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
|
| 268 |
+
upsample_factor=upsample_factor,
|
| 269 |
+
is_depth=task == 'depth')
|
| 270 |
+
flow_preds.append(flow_up)
|
| 271 |
+
|
| 272 |
+
assert num_reg_refine > 0
|
| 273 |
+
for refine_iter_idx in range(num_reg_refine):
|
| 274 |
+
flow = flow.detach()
|
| 275 |
+
|
| 276 |
+
if task == 'stereo':
|
| 277 |
+
zeros = torch.zeros_like(flow) # [B, 1, H, W]
|
| 278 |
+
# NOTE: reverse disp, disparity is positive
|
| 279 |
+
displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
|
| 280 |
+
correlation = local_correlation_with_flow(
|
| 281 |
+
feature0_ori,
|
| 282 |
+
feature1_ori,
|
| 283 |
+
flow=displace,
|
| 284 |
+
local_radius=4,
|
| 285 |
+
) # [B, (2R+1)^2, H, W]
|
| 286 |
+
elif task == 'depth':
|
| 287 |
+
if pred_bidir_depth and refine_iter_idx == 0:
|
| 288 |
+
intrinsics_curr = intrinsics_curr.repeat(2, 1, 1)
|
| 289 |
+
pose = torch.cat((pose, torch.inverse(pose)), dim=0)
|
| 290 |
+
|
| 291 |
+
feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori),
|
| 292 |
+
dim=0), torch.cat((feature1_ori,
|
| 293 |
+
feature0_ori), dim=0)
|
| 294 |
+
|
| 295 |
+
flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
|
| 296 |
+
intrinsics_curr,
|
| 297 |
+
extrinsics_rel=pose,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
correlation = local_correlation_with_flow(
|
| 301 |
+
feature0_ori,
|
| 302 |
+
feature1_ori,
|
| 303 |
+
flow=flow_from_depth,
|
| 304 |
+
local_radius=4,
|
| 305 |
+
) # [B, (2R+1)^2, H, W]
|
| 306 |
+
|
| 307 |
+
else:
|
| 308 |
+
correlation = local_correlation_with_flow(
|
| 309 |
+
feature0_ori,
|
| 310 |
+
feature1_ori,
|
| 311 |
+
flow=flow,
|
| 312 |
+
local_radius=4,
|
| 313 |
+
) # [B, (2R+1)^2, H, W]
|
| 314 |
+
|
| 315 |
+
proj = self.refine_proj(feature0)
|
| 316 |
+
|
| 317 |
+
net, inp = torch.chunk(proj, chunks=2, dim=1)
|
| 318 |
+
|
| 319 |
+
net = torch.tanh(net)
|
| 320 |
+
inp = torch.relu(inp)
|
| 321 |
+
|
| 322 |
+
net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(),
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if task == 'depth':
|
| 326 |
+
flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth)
|
| 327 |
+
else:
|
| 328 |
+
flow = flow + residual_flow
|
| 329 |
+
|
| 330 |
+
if task == 'stereo':
|
| 331 |
+
flow = flow.clamp(min=0) # positive
|
| 332 |
+
|
| 333 |
+
if self.training or refine_iter_idx == num_reg_refine - 1:
|
| 334 |
+
if task == 'depth':
|
| 335 |
+
if refine_iter_idx < num_reg_refine - 1:
|
| 336 |
+
# bilinear upsampling
|
| 337 |
+
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
|
| 338 |
+
upsample_factor=upsample_factor,
|
| 339 |
+
is_depth=True)
|
| 340 |
+
else:
|
| 341 |
+
# last one convex upsampling
|
| 342 |
+
# NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling
|
| 343 |
+
# pad depth to 2 channels as flow
|
| 344 |
+
depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
|
| 345 |
+
depth_up_pad = self.upsample_flow(depth_pad, feature0,
|
| 346 |
+
is_depth=True).clamp(min=min_depth,
|
| 347 |
+
max=max_depth)
|
| 348 |
+
flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
|
| 349 |
+
|
| 350 |
+
else:
|
| 351 |
+
flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor,
|
| 352 |
+
is_depth=task == 'depth')
|
| 353 |
+
|
| 354 |
+
flow_preds.append(flow_up)
|
| 355 |
+
|
| 356 |
+
if task == 'stereo':
|
| 357 |
+
for i in range(len(flow_preds)):
|
| 358 |
+
flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W]
|
| 359 |
+
|
| 360 |
+
# convert inverse depth to depth
|
| 361 |
+
if task == 'depth':
|
| 362 |
+
for i in range(len(flow_preds)):
|
| 363 |
+
flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W]
|
| 364 |
+
|
| 365 |
+
results_dict.update({'flow_preds': flow_preds})
|
| 366 |
+
|
| 367 |
+
return results_dict
|
unimatch/utils.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from .position import PositionEmbeddingSine
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
|
| 7 |
+
assert device is not None
|
| 8 |
+
|
| 9 |
+
x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
|
| 10 |
+
torch.linspace(h_min, h_max, len_h, device=device)],
|
| 11 |
+
)
|
| 12 |
+
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
|
| 13 |
+
|
| 14 |
+
return grid
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def normalize_coords(coords, h, w):
|
| 18 |
+
# coords: [B, H, W, 2]
|
| 19 |
+
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
|
| 20 |
+
return (coords - c) / c # [-1, 1]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def normalize_img(img0, img1):
|
| 24 |
+
# loaded images are in [0, 255]
|
| 25 |
+
# normalize by ImageNet mean and std
|
| 26 |
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
|
| 27 |
+
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
|
| 28 |
+
img0 = (img0 / 255. - mean) / std
|
| 29 |
+
img1 = (img1 / 255. - mean) / std
|
| 30 |
+
|
| 31 |
+
return img0, img1
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def split_feature(feature,
|
| 35 |
+
num_splits=2,
|
| 36 |
+
channel_last=False,
|
| 37 |
+
):
|
| 38 |
+
if channel_last: # [B, H, W, C]
|
| 39 |
+
b, h, w, c = feature.size()
|
| 40 |
+
assert h % num_splits == 0 and w % num_splits == 0
|
| 41 |
+
|
| 42 |
+
b_new = b * num_splits * num_splits
|
| 43 |
+
h_new = h // num_splits
|
| 44 |
+
w_new = w // num_splits
|
| 45 |
+
|
| 46 |
+
feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
|
| 47 |
+
).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
|
| 48 |
+
else: # [B, C, H, W]
|
| 49 |
+
b, c, h, w = feature.size()
|
| 50 |
+
assert h % num_splits == 0 and w % num_splits == 0
|
| 51 |
+
|
| 52 |
+
b_new = b * num_splits * num_splits
|
| 53 |
+
h_new = h // num_splits
|
| 54 |
+
w_new = w // num_splits
|
| 55 |
+
|
| 56 |
+
feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
|
| 57 |
+
).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
|
| 58 |
+
|
| 59 |
+
return feature
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def merge_splits(splits,
|
| 63 |
+
num_splits=2,
|
| 64 |
+
channel_last=False,
|
| 65 |
+
):
|
| 66 |
+
if channel_last: # [B*K*K, H/K, W/K, C]
|
| 67 |
+
b, h, w, c = splits.size()
|
| 68 |
+
new_b = b // num_splits // num_splits
|
| 69 |
+
|
| 70 |
+
splits = splits.view(new_b, num_splits, num_splits, h, w, c)
|
| 71 |
+
merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
|
| 72 |
+
new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
|
| 73 |
+
else: # [B*K*K, C, H/K, W/K]
|
| 74 |
+
b, c, h, w = splits.size()
|
| 75 |
+
new_b = b // num_splits // num_splits
|
| 76 |
+
|
| 77 |
+
splits = splits.view(new_b, num_splits, num_splits, c, h, w)
|
| 78 |
+
merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
|
| 79 |
+
new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
|
| 80 |
+
|
| 81 |
+
return merge
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
|
| 85 |
+
shift_size_h, shift_size_w, device=torch.device('cuda')):
|
| 86 |
+
# ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
|
| 87 |
+
# calculate attention mask for SW-MSA
|
| 88 |
+
h, w = input_resolution
|
| 89 |
+
img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
|
| 90 |
+
h_slices = (slice(0, -window_size_h),
|
| 91 |
+
slice(-window_size_h, -shift_size_h),
|
| 92 |
+
slice(-shift_size_h, None))
|
| 93 |
+
w_slices = (slice(0, -window_size_w),
|
| 94 |
+
slice(-window_size_w, -shift_size_w),
|
| 95 |
+
slice(-shift_size_w, None))
|
| 96 |
+
cnt = 0
|
| 97 |
+
for h in h_slices:
|
| 98 |
+
for w in w_slices:
|
| 99 |
+
img_mask[:, h, w, :] = cnt
|
| 100 |
+
cnt += 1
|
| 101 |
+
|
| 102 |
+
mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
|
| 103 |
+
|
| 104 |
+
mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
|
| 105 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 106 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 107 |
+
|
| 108 |
+
return attn_mask
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
|
| 112 |
+
pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
|
| 113 |
+
|
| 114 |
+
if attn_splits > 1: # add position in splited window
|
| 115 |
+
feature0_splits = split_feature(feature0, num_splits=attn_splits)
|
| 116 |
+
feature1_splits = split_feature(feature1, num_splits=attn_splits)
|
| 117 |
+
|
| 118 |
+
position = pos_enc(feature0_splits)
|
| 119 |
+
|
| 120 |
+
feature0_splits = feature0_splits + position
|
| 121 |
+
feature1_splits = feature1_splits + position
|
| 122 |
+
|
| 123 |
+
feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
|
| 124 |
+
feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
|
| 125 |
+
else:
|
| 126 |
+
position = pos_enc(feature0)
|
| 127 |
+
|
| 128 |
+
feature0 = feature0 + position
|
| 129 |
+
feature1 = feature1 + position
|
| 130 |
+
|
| 131 |
+
return feature0, feature1
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def upsample_flow_with_mask(flow, up_mask, upsample_factor,
|
| 135 |
+
is_depth=False):
|
| 136 |
+
# convex upsampling following raft
|
| 137 |
+
|
| 138 |
+
mask = up_mask
|
| 139 |
+
b, flow_channel, h, w = flow.shape
|
| 140 |
+
mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
|
| 141 |
+
mask = torch.softmax(mask, dim=2)
|
| 142 |
+
|
| 143 |
+
multiplier = 1 if is_depth else upsample_factor
|
| 144 |
+
up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
|
| 145 |
+
up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
|
| 146 |
+
|
| 147 |
+
up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
|
| 148 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
|
| 149 |
+
up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
|
| 150 |
+
upsample_factor * w) # [B, 2, K*H, K*W]
|
| 151 |
+
|
| 152 |
+
return up_flow
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def split_feature_1d(feature,
|
| 156 |
+
num_splits=2,
|
| 157 |
+
):
|
| 158 |
+
# feature: [B, W, C]
|
| 159 |
+
b, w, c = feature.size()
|
| 160 |
+
assert w % num_splits == 0
|
| 161 |
+
|
| 162 |
+
b_new = b * num_splits
|
| 163 |
+
w_new = w // num_splits
|
| 164 |
+
|
| 165 |
+
feature = feature.view(b, num_splits, w // num_splits, c
|
| 166 |
+
).view(b_new, w_new, c) # [B*K, W/K, C]
|
| 167 |
+
|
| 168 |
+
return feature
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def merge_splits_1d(splits,
|
| 172 |
+
h,
|
| 173 |
+
num_splits=2,
|
| 174 |
+
):
|
| 175 |
+
b, w, c = splits.size()
|
| 176 |
+
new_b = b // num_splits // h
|
| 177 |
+
|
| 178 |
+
splits = splits.view(new_b, h, num_splits, w, c)
|
| 179 |
+
merge = splits.view(
|
| 180 |
+
new_b, h, num_splits * w, c) # [B, H, W, C]
|
| 181 |
+
|
| 182 |
+
return merge
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def window_partition_1d(x, window_size_w):
|
| 186 |
+
"""
|
| 187 |
+
Args:
|
| 188 |
+
x: (B, W, C)
|
| 189 |
+
window_size (int): window size
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
windows: (num_windows*B, window_size, C)
|
| 193 |
+
"""
|
| 194 |
+
B, W, C = x.shape
|
| 195 |
+
x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
|
| 196 |
+
return x
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def generate_shift_window_attn_mask_1d(input_w, window_size_w,
|
| 200 |
+
shift_size_w, device=torch.device('cuda')):
|
| 201 |
+
# calculate attention mask for SW-MSA
|
| 202 |
+
img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1
|
| 203 |
+
w_slices = (slice(0, -window_size_w),
|
| 204 |
+
slice(-window_size_w, -shift_size_w),
|
| 205 |
+
slice(-shift_size_w, None))
|
| 206 |
+
cnt = 0
|
| 207 |
+
for w in w_slices:
|
| 208 |
+
img_mask[:, w, :] = cnt
|
| 209 |
+
cnt += 1
|
| 210 |
+
|
| 211 |
+
mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1
|
| 212 |
+
mask_windows = mask_windows.view(-1, window_size_w)
|
| 213 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size
|
| 214 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 215 |
+
|
| 216 |
+
return attn_mask
|
utils/flow_viz.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2018 Tom Runia
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to conditions.
|
| 11 |
+
#
|
| 12 |
+
# Author: Tom Runia
|
| 13 |
+
# Date Created: 2018-08-03
|
| 14 |
+
|
| 15 |
+
from __future__ import absolute_import
|
| 16 |
+
from __future__ import division
|
| 17 |
+
from __future__ import print_function
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_colorwheel():
|
| 24 |
+
'''
|
| 25 |
+
Generates a color wheel for optical flow visualization as presented in:
|
| 26 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
| 27 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
| 28 |
+
According to the C++ source code of Daniel Scharstein
|
| 29 |
+
According to the Matlab source code of Deqing Sun
|
| 30 |
+
'''
|
| 31 |
+
|
| 32 |
+
RY = 15
|
| 33 |
+
YG = 6
|
| 34 |
+
GC = 4
|
| 35 |
+
CB = 11
|
| 36 |
+
BM = 13
|
| 37 |
+
MR = 6
|
| 38 |
+
|
| 39 |
+
ncols = RY + YG + GC + CB + BM + MR
|
| 40 |
+
colorwheel = np.zeros((ncols, 3))
|
| 41 |
+
col = 0
|
| 42 |
+
|
| 43 |
+
# RY
|
| 44 |
+
colorwheel[0:RY, 0] = 255
|
| 45 |
+
colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
|
| 46 |
+
col = col + RY
|
| 47 |
+
# YG
|
| 48 |
+
colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
|
| 49 |
+
colorwheel[col:col + YG, 1] = 255
|
| 50 |
+
col = col + YG
|
| 51 |
+
# GC
|
| 52 |
+
colorwheel[col:col + GC, 1] = 255
|
| 53 |
+
colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
|
| 54 |
+
col = col + GC
|
| 55 |
+
# CB
|
| 56 |
+
colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
|
| 57 |
+
colorwheel[col:col + CB, 2] = 255
|
| 58 |
+
col = col + CB
|
| 59 |
+
# BM
|
| 60 |
+
colorwheel[col:col + BM, 2] = 255
|
| 61 |
+
colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
|
| 62 |
+
col = col + BM
|
| 63 |
+
# MR
|
| 64 |
+
colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
|
| 65 |
+
colorwheel[col:col + MR, 0] = 255
|
| 66 |
+
return colorwheel
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def flow_compute_color(u, v, convert_to_bgr=False):
|
| 70 |
+
'''
|
| 71 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
| 72 |
+
According to the C++ source code of Daniel Scharstein
|
| 73 |
+
According to the Matlab source code of Deqing Sun
|
| 74 |
+
:param u: np.ndarray, input horizontal flow
|
| 75 |
+
:param v: np.ndarray, input vertical flow
|
| 76 |
+
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
|
| 77 |
+
:return:
|
| 78 |
+
'''
|
| 79 |
+
|
| 80 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
| 81 |
+
|
| 82 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
| 83 |
+
ncols = colorwheel.shape[0]
|
| 84 |
+
|
| 85 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 86 |
+
a = np.arctan2(-v, -u) / np.pi
|
| 87 |
+
|
| 88 |
+
fk = (a + 1) / 2 * (ncols - 1) + 1
|
| 89 |
+
k0 = np.floor(fk).astype(np.int32)
|
| 90 |
+
k1 = k0 + 1
|
| 91 |
+
k1[k1 == ncols] = 1
|
| 92 |
+
f = fk - k0
|
| 93 |
+
|
| 94 |
+
for i in range(colorwheel.shape[1]):
|
| 95 |
+
tmp = colorwheel[:, i]
|
| 96 |
+
col0 = tmp[k0] / 255.0
|
| 97 |
+
col1 = tmp[k1] / 255.0
|
| 98 |
+
col = (1 - f) * col0 + f * col1
|
| 99 |
+
|
| 100 |
+
idx = (rad <= 1)
|
| 101 |
+
col[idx] = 1 - rad[idx] * (1 - col[idx])
|
| 102 |
+
col[~idx] = col[~idx] * 0.75 # out of range?
|
| 103 |
+
|
| 104 |
+
# Note the 2-i => BGR instead of RGB
|
| 105 |
+
ch_idx = 2 - i if convert_to_bgr else i
|
| 106 |
+
flow_image[:, :, ch_idx] = np.floor(255 * col)
|
| 107 |
+
|
| 108 |
+
return flow_image
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
|
| 112 |
+
'''
|
| 113 |
+
Expects a two dimensional flow image of shape [H,W,2]
|
| 114 |
+
According to the C++ source code of Daniel Scharstein
|
| 115 |
+
According to the Matlab source code of Deqing Sun
|
| 116 |
+
:param flow_uv: np.ndarray of shape [H,W,2]
|
| 117 |
+
:param clip_flow: float, maximum clipping value for flow
|
| 118 |
+
:return:
|
| 119 |
+
'''
|
| 120 |
+
|
| 121 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
| 122 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
| 123 |
+
|
| 124 |
+
if clip_flow is not None:
|
| 125 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
| 126 |
+
|
| 127 |
+
u = flow_uv[:, :, 0]
|
| 128 |
+
v = flow_uv[:, :, 1]
|
| 129 |
+
|
| 130 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 131 |
+
rad_max = np.max(rad)
|
| 132 |
+
|
| 133 |
+
epsilon = 1e-5
|
| 134 |
+
u = u / (rad_max + epsilon)
|
| 135 |
+
v = v / (rad_max + epsilon)
|
| 136 |
+
|
| 137 |
+
return flow_compute_color(u, v, convert_to_bgr)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
UNKNOWN_FLOW_THRESH = 1e7
|
| 141 |
+
SMALLFLOW = 0.0
|
| 142 |
+
LARGEFLOW = 1e8
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def make_color_wheel():
|
| 146 |
+
"""
|
| 147 |
+
Generate color wheel according Middlebury color code
|
| 148 |
+
:return: Color wheel
|
| 149 |
+
"""
|
| 150 |
+
RY = 15
|
| 151 |
+
YG = 6
|
| 152 |
+
GC = 4
|
| 153 |
+
CB = 11
|
| 154 |
+
BM = 13
|
| 155 |
+
MR = 6
|
| 156 |
+
|
| 157 |
+
ncols = RY + YG + GC + CB + BM + MR
|
| 158 |
+
|
| 159 |
+
colorwheel = np.zeros([ncols, 3])
|
| 160 |
+
|
| 161 |
+
col = 0
|
| 162 |
+
|
| 163 |
+
# RY
|
| 164 |
+
colorwheel[0:RY, 0] = 255
|
| 165 |
+
colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
|
| 166 |
+
col += RY
|
| 167 |
+
|
| 168 |
+
# YG
|
| 169 |
+
colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
|
| 170 |
+
colorwheel[col:col + YG, 1] = 255
|
| 171 |
+
col += YG
|
| 172 |
+
|
| 173 |
+
# GC
|
| 174 |
+
colorwheel[col:col + GC, 1] = 255
|
| 175 |
+
colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
|
| 176 |
+
col += GC
|
| 177 |
+
|
| 178 |
+
# CB
|
| 179 |
+
colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
|
| 180 |
+
colorwheel[col:col + CB, 2] = 255
|
| 181 |
+
col += CB
|
| 182 |
+
|
| 183 |
+
# BM
|
| 184 |
+
colorwheel[col:col + BM, 2] = 255
|
| 185 |
+
colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
|
| 186 |
+
col += + BM
|
| 187 |
+
|
| 188 |
+
# MR
|
| 189 |
+
colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
|
| 190 |
+
colorwheel[col:col + MR, 0] = 255
|
| 191 |
+
|
| 192 |
+
return colorwheel
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def compute_color(u, v):
|
| 196 |
+
"""
|
| 197 |
+
compute optical flow color map
|
| 198 |
+
:param u: optical flow horizontal map
|
| 199 |
+
:param v: optical flow vertical map
|
| 200 |
+
:return: optical flow in color code
|
| 201 |
+
"""
|
| 202 |
+
[h, w] = u.shape
|
| 203 |
+
img = np.zeros([h, w, 3])
|
| 204 |
+
nanIdx = np.isnan(u) | np.isnan(v)
|
| 205 |
+
u[nanIdx] = 0
|
| 206 |
+
v[nanIdx] = 0
|
| 207 |
+
|
| 208 |
+
colorwheel = make_color_wheel()
|
| 209 |
+
ncols = np.size(colorwheel, 0)
|
| 210 |
+
|
| 211 |
+
rad = np.sqrt(u ** 2 + v ** 2)
|
| 212 |
+
|
| 213 |
+
a = np.arctan2(-v, -u) / np.pi
|
| 214 |
+
|
| 215 |
+
fk = (a + 1) / 2 * (ncols - 1) + 1
|
| 216 |
+
|
| 217 |
+
k0 = np.floor(fk).astype(int)
|
| 218 |
+
|
| 219 |
+
k1 = k0 + 1
|
| 220 |
+
k1[k1 == ncols + 1] = 1
|
| 221 |
+
f = fk - k0
|
| 222 |
+
|
| 223 |
+
for i in range(0, np.size(colorwheel, 1)):
|
| 224 |
+
tmp = colorwheel[:, i]
|
| 225 |
+
col0 = tmp[k0 - 1] / 255
|
| 226 |
+
col1 = tmp[k1 - 1] / 255
|
| 227 |
+
col = (1 - f) * col0 + f * col1
|
| 228 |
+
|
| 229 |
+
idx = rad <= 1
|
| 230 |
+
col[idx] = 1 - rad[idx] * (1 - col[idx])
|
| 231 |
+
notidx = np.logical_not(idx)
|
| 232 |
+
|
| 233 |
+
col[notidx] *= 0.75
|
| 234 |
+
img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
|
| 235 |
+
|
| 236 |
+
return img
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# from https://github.com/gengshan-y/VCN
|
| 240 |
+
def flow_to_image(flow):
|
| 241 |
+
"""
|
| 242 |
+
Convert flow into middlebury color code image
|
| 243 |
+
:param flow: optical flow map
|
| 244 |
+
:return: optical flow image in middlebury color
|
| 245 |
+
"""
|
| 246 |
+
u = flow[:, :, 0]
|
| 247 |
+
v = flow[:, :, 1]
|
| 248 |
+
|
| 249 |
+
maxu = -999.
|
| 250 |
+
maxv = -999.
|
| 251 |
+
minu = 999.
|
| 252 |
+
minv = 999.
|
| 253 |
+
|
| 254 |
+
idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
|
| 255 |
+
u[idxUnknow] = 0
|
| 256 |
+
v[idxUnknow] = 0
|
| 257 |
+
|
| 258 |
+
maxu = max(maxu, np.max(u))
|
| 259 |
+
minu = min(minu, np.min(u))
|
| 260 |
+
|
| 261 |
+
maxv = max(maxv, np.max(v))
|
| 262 |
+
minv = min(minv, np.min(v))
|
| 263 |
+
|
| 264 |
+
rad = np.sqrt(u ** 2 + v ** 2)
|
| 265 |
+
maxrad = max(-1, np.max(rad))
|
| 266 |
+
|
| 267 |
+
u = u / (maxrad + np.finfo(float).eps)
|
| 268 |
+
v = v / (maxrad + np.finfo(float).eps)
|
| 269 |
+
|
| 270 |
+
img = compute_color(u, v)
|
| 271 |
+
|
| 272 |
+
idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
|
| 273 |
+
img[idx] = 0
|
| 274 |
+
|
| 275 |
+
return np.uint8(img)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def save_vis_flow_tofile(flow, output_path):
|
| 279 |
+
vis_flow = flow_to_image(flow)
|
| 280 |
+
Image.fromarray(vis_flow).save(output_path)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def flow_tensor_to_image(flow):
|
| 284 |
+
"""Used for tensorboard visualization"""
|
| 285 |
+
flow = flow.permute(1, 2, 0) # [H, W, 2]
|
| 286 |
+
flow = flow.detach().cpu().numpy()
|
| 287 |
+
flow = flow_to_image(flow) # [H, W, 3]
|
| 288 |
+
flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
|
| 289 |
+
|
| 290 |
+
return flow
|
utils/visualization.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils.data
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torchvision.utils as vutils
|
| 5 |
+
import cv2
|
| 6 |
+
from matplotlib.cm import get_cmap
|
| 7 |
+
import matplotlib as mpl
|
| 8 |
+
import matplotlib.cm as cm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def vis_disparity(disp, return_rgb=False):
|
| 12 |
+
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
| 13 |
+
disp_vis = disp_vis.astype("uint8")
|
| 14 |
+
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
| 15 |
+
|
| 16 |
+
if return_rgb:
|
| 17 |
+
disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
|
| 18 |
+
|
| 19 |
+
return disp_vis
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gen_error_colormap():
|
| 23 |
+
cols = np.array(
|
| 24 |
+
[[0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
|
| 25 |
+
[0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
|
| 26 |
+
[0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
|
| 27 |
+
[0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
|
| 28 |
+
[1.5 / 3.0, 3 / 3.0, 224, 243, 248],
|
| 29 |
+
[3 / 3.0, 6 / 3.0, 254, 224, 144],
|
| 30 |
+
[6 / 3.0, 12 / 3.0, 253, 174, 97],
|
| 31 |
+
[12 / 3.0, 24 / 3.0, 244, 109, 67],
|
| 32 |
+
[24 / 3.0, 48 / 3.0, 215, 48, 39],
|
| 33 |
+
[48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32)
|
| 34 |
+
cols[:, 2: 5] /= 255.
|
| 35 |
+
return cols
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1):
|
| 39 |
+
D_gt_np = D_gt_tensor.detach().cpu().numpy()
|
| 40 |
+
D_est_np = D_est_tensor.detach().cpu().numpy()
|
| 41 |
+
B, H, W = D_gt_np.shape
|
| 42 |
+
# valid mask
|
| 43 |
+
mask = D_gt_np > 0
|
| 44 |
+
# error in percentage. When error <= 1, the pixel is valid since <= 3px & 5%
|
| 45 |
+
error = np.abs(D_gt_np - D_est_np)
|
| 46 |
+
error[np.logical_not(mask)] = 0
|
| 47 |
+
error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres)
|
| 48 |
+
# get colormap
|
| 49 |
+
cols = gen_error_colormap()
|
| 50 |
+
# create error image
|
| 51 |
+
error_image = np.zeros([B, H, W, 3], dtype=np.float32)
|
| 52 |
+
for i in range(cols.shape[0]):
|
| 53 |
+
error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:]
|
| 54 |
+
# TODO: imdilate
|
| 55 |
+
# error_image = cv2.imdilate(D_err, strel('disk', dilate_radius));
|
| 56 |
+
error_image[np.logical_not(mask)] = 0.
|
| 57 |
+
# show color tag in the top-left cornor of the image
|
| 58 |
+
for i in range(cols.shape[0]):
|
| 59 |
+
distance = 20
|
| 60 |
+
error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:]
|
| 61 |
+
|
| 62 |
+
return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2])))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def save_images(logger, mode_tag, images_dict, global_step):
|
| 66 |
+
images_dict = tensor2numpy(images_dict)
|
| 67 |
+
for tag, values in images_dict.items():
|
| 68 |
+
if not isinstance(values, list) and not isinstance(values, tuple):
|
| 69 |
+
values = [values]
|
| 70 |
+
for idx, value in enumerate(values):
|
| 71 |
+
if len(value.shape) == 3:
|
| 72 |
+
value = value[:, np.newaxis, :, :]
|
| 73 |
+
value = value[:1]
|
| 74 |
+
value = torch.from_numpy(value)
|
| 75 |
+
|
| 76 |
+
image_name = '{}/{}'.format(mode_tag, tag)
|
| 77 |
+
if len(values) > 1:
|
| 78 |
+
image_name = image_name + "_" + str(idx)
|
| 79 |
+
logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True),
|
| 80 |
+
global_step)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def tensor2numpy(var_dict):
|
| 84 |
+
for key, vars in var_dict.items():
|
| 85 |
+
if isinstance(vars, np.ndarray):
|
| 86 |
+
var_dict[key] = vars
|
| 87 |
+
elif isinstance(vars, torch.Tensor):
|
| 88 |
+
var_dict[key] = vars.data.cpu().numpy()
|
| 89 |
+
else:
|
| 90 |
+
raise NotImplementedError("invalid input type for tensor2numpy")
|
| 91 |
+
|
| 92 |
+
return var_dict
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def viz_depth_tensor_from_monodepth2(disp, return_numpy=False, colormap='plasma'):
|
| 96 |
+
# visualize inverse depth
|
| 97 |
+
assert isinstance(disp, torch.Tensor)
|
| 98 |
+
|
| 99 |
+
disp = disp.numpy()
|
| 100 |
+
vmax = np.percentile(disp, 95)
|
| 101 |
+
normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
|
| 102 |
+
mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
|
| 103 |
+
colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3]
|
| 104 |
+
|
| 105 |
+
if return_numpy:
|
| 106 |
+
return colormapped_im
|
| 107 |
+
|
| 108 |
+
viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W]
|
| 109 |
+
|
| 110 |
+
return viz
|