code release
Signed-off-by: tingmany <tmyann@outlook.com>
- .gitattributes +1 -1
- README.md +1 -1
- dataloader/stereo/transforms.py +82 -0
- gradio_app.py +371 -0
- models/__init__.py +0 -0
- models/attention_blocks.py +210 -0
- models/common.py +48 -0
- models/compile.sh +4 -0
- models/convformer.py +391 -0
- models/cost_volume.py +179 -0
- models/mat_pytorch_impl.py +178 -0
- models/match_former_ops.py +121 -0
- models/match_stereo.py +130 -0
- models/setup.py +20 -0
- models/src/match_former_cuda.cpp +49 -0
- models/src/match_former_cuda_kernel.cu +26 -0
- models/src/match_former_fused_forward.cu +628 -0
- models/src/match_former_fused_forward.hpp +22 -0
- requirements.txt +14 -0
- utils/file_io.py +37 -0
- utils/utils.py +58 -0
.gitattributes
CHANGED

@@ -33,4 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-
+*.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED

@@ -5,7 +5,7 @@ colorFrom: green
 colorTo: blue
 sdk: gradio
 sdk_version: 5.49.1
-app_file:
+app_file: gradio_app.py
 pinned: false
 license: gpl-3.0
 ---
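This one-line change is what wires the demo up: on Hugging Face Spaces, `app_file` names the script the Space executes, so the Space now launches the gradio_app.py added below.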
dataloader/stereo/transforms.py
ADDED

from __future__ import division
import torch
import numpy as np
import cv2


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)
        return sample


class ToTensor(object):
    """Convert numpy array to torch tensor"""

    def __init__(self, no_normalize=False):
        self.no_normalize = no_normalize

    def __call__(self, sample):
        left = np.transpose(sample['left'], (2, 0, 1))  # [3, H, W]
        if self.no_normalize:
            sample['left'] = torch.from_numpy(left)
        else:
            sample['left'] = torch.from_numpy(left) / 255.
        right = np.transpose(sample['right'], (2, 0, 1))

        if self.no_normalize:
            sample['right'] = torch.from_numpy(right)
        else:
            sample['right'] = torch.from_numpy(right) / 255.

        if 'disp' in sample.keys():
            disp = sample['disp']  # [H, W]
            sample['disp'] = torch.from_numpy(disp)
        if 'disp_r' in sample.keys():
            disp_r = sample['disp_r']  # [H, W]
            sample['disp_r'] = torch.from_numpy(disp_r)

        if 'valid' in sample.keys():
            valid = sample['valid']  # [H, W]
            sample['valid'] = torch.from_numpy(valid)

        return sample


class Resize(object):
    def __init__(self,
                 scale_x=1,
                 scale_y=1,
                 nearest_interp=True,  # for sparse gt
                 ):
        """
        Resize low-resolution data to high-res for mixed dataset training
        """
        self.scale_x = scale_x
        self.scale_y = scale_y
        self.nearest_interp = nearest_interp

    def __call__(self, sample):
        scale_x = self.scale_x
        scale_y = self.scale_y

        sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
        sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)

        if 'disp' in sample.keys():
            sample['disp'] = cv2.resize(
                sample['disp'], None, fx=scale_x, fy=scale_y,
                interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
            ) * scale_x

        if 'disp_r' in sample.keys():
            sample['disp_r'] = cv2.resize(
                sample['disp_r'], None, fx=scale_x, fy=scale_y,
                interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
            ) * scale_x

        return sample
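As a quick orientation, here is a minimal usage sketch of these transforms (array sizes are made up for the example; the demo app below uses ToTensor(no_normalize=True) the same way). Note that Resize multiplies the disparity values by scale_x so they stay consistent with the new image width:

import numpy as np
from dataloader.stereo import transforms

# Dummy HWC float32 images and a dense disparity map (hypothetical sizes).
sample = {
    'left': np.random.rand(480, 640, 3).astype(np.float32) * 255,
    'right': np.random.rand(480, 640, 3).astype(np.float32) * 255,
    'disp': np.random.rand(480, 640).astype(np.float32) * 64,
}

# Upscale 2x (nearest interpolation keeps sparse ground truth valid),
# then convert everything to torch tensors without /255 normalization.
transform = transforms.Compose([
    transforms.Resize(scale_x=2, scale_y=2, nearest_interp=True),
    transforms.ToTensor(no_normalize=True),
])
sample = transform(sample)
print(sample['left'].shape)  # torch.Size([3, 960, 1280])
print(sample['disp'].shape)  # torch.Size([960, 1280]), values doubled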
gradio_app.py
ADDED

import gradio as gr
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import os
import time
import spaces
import cv2  # required by flow_to_color

from dataloader.stereo import transforms
from utils.utils import InputPadder, calc_noc_mask
from huggingface_hub import hf_hub_download
from models.match_stereo import MatchStereo

torch.backends.cudnn.benchmark = True

class MatchStereoDemo:
    def __init__(self):
        self.has_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0') if self.has_cuda else 'cpu'
        self.model = None
        self.current_variant = None
        self.current_mode = None
        self.current_precision = None
        self.current_mat_impl = None
        self.download_model()

    def download_model(self):
        REPO_ID = 'Tingman/MatchAttention'
        filename_list = ['matchstereo_tiny_fsd.pth', 'matchstereo_small_fsd.pth', 'matchstereo_base_fsd.pth', 'matchflow_base_sintel.pth']
        if not os.path.exists('./checkpoints/'):
            os.makedirs('./checkpoints/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/', local_dir_use_symlinks=False)

    def load_model(self, mode, variant, precision, mat_impl):
        """load model, skip if the model has been loaded"""
        if (self.model is not None and
                self.current_variant == variant and
                self.current_mode == mode and
                self.current_precision == precision and
                self.current_mat_impl == mat_impl):
            return "Model already loaded"

        # fixed checkpoint path
        checkpoint_base_path = "./checkpoints"
        if mode == 'stereo':
            checkpoint_name = f"match{mode}_{variant}_fsd.pth"
        elif mode == 'flow':
            checkpoint_name = f"match{mode}_{variant}_sintel.pth"
        else:
            raise NotImplementedError

        checkpoint_path = os.path.join(checkpoint_base_path, checkpoint_name)

        if not os.path.exists(checkpoint_path):
            return f"Error: Checkpoint not found at {checkpoint_path}"

        args = argparse.Namespace()
        args.mode = mode
        args.variant = variant
        args.mat_impl = mat_impl

        if not self.has_cuda:
            precision = "fp32"
        dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
        self.dtype = dtypes[precision]

        self.model = MatchStereo(args)

        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.model.load_state_dict(state_dict=checkpoint['model'], strict=False)
            self.model.to(self.device)
            self.model.eval()
            self.model = self.model.to(self.dtype)

            self._warmup_model()

            self.current_variant = variant
            self.current_mode = mode
            self.current_precision = precision
            self.current_mat_impl = mat_impl

            device_info = "GPU" if self.has_cuda else "CPU"
            return f"Successfully loaded {mode} {variant} model on {device_info} (precision: {precision}, mat_impl: {mat_impl})"
        except Exception as e:
            return f"Error loading model: {str(e)}"

    def _warmup_model(self):
        """warmup the model for accurate time measurement"""
        if self.model is None:
            return

        dummy_left = torch.randn(1, 3, 512, 512, device=self.device, dtype=self.dtype)
        dummy_right = torch.randn(1, 3, 512, 512, device=self.device, dtype=self.dtype)

        with torch.no_grad():
            _ = self.model(dummy_left, dummy_right, stereo=(self.current_mode == 'stereo'))

    def run_frame(self, left, right, stereo, low_res_init=False, factor=2.):
        """single frame inference"""
        if low_res_init:
            left_ds = F.interpolate(left, scale_factor=1/factor, mode='bilinear', align_corners=True)
            right_ds = F.interpolate(right, scale_factor=1/factor, mode='bilinear', align_corners=True)
            padder_ds = InputPadder(left_ds.shape, padding_factor=32)
            left_ds, right_ds = padder_ds.pad(left_ds, right_ds)

            field_up_ds = self.model(left_ds, right_ds, stereo=stereo)['field_up']
            field_up_ds = padder_ds.unpad(field_up_ds.permute(0, 3, 1, 2).contiguous()).contiguous()
            field_up_init = F.interpolate(field_up_ds, scale_factor=factor/32, mode='bilinear', align_corners=True) * (factor/32)
            field_up_init = field_up_init.permute(0, 2, 3, 1).contiguous()
            results_dict = self.model(left, right, stereo=stereo, init_flow=field_up_init)
        else:
            results_dict = self.model(left, right, stereo=stereo)

        return results_dict

    def get_inference_size(self, size_name):
        if size_name == "Original":
            return None

        def round_to_32(x):
            return (x + 16) // 32 * 32

        size_presets = {
            "720P": (round_to_32(1280), round_to_32(720)),
            "1080P": (round_to_32(1920), round_to_32(1080)),
            "2K": (round_to_32(2048), round_to_32(1080)),
            "4K UHD": (round_to_32(3840), round_to_32(2160))
        }

        return size_presets.get(size_name, None)

    def process_images(self, left_image, right_image, mode, variant,
                       low_res_init=False, inference_size_name="Original",
                       precision="fp32", mat_impl="pytorch"):
        if not self.has_cuda:
            precision = "fp32"
            mat_impl = "pytorch"

        load_result = self.load_model(mode, variant, precision, mat_impl)
        if load_result.startswith("Error"):
            return None, None, None, load_result

        try:
            left = np.array(left_image.convert('RGB')).astype(np.float32)
            right = np.array(right_image.convert('RGB')).astype(np.float32)

            original_size = left.shape[:2]  # (H, W)

            inference_size = self.get_inference_size(inference_size_name)

            val_transform_list = [transforms.ToTensor(no_normalize=True)]
            val_transform = transforms.Compose(val_transform_list)

            sample = {'left': left, 'right': right}
            sample = val_transform(sample)
            left_tensor = sample['left'].to(self.device, dtype=self.dtype).unsqueeze(0)
            right_tensor = sample['right'].to(self.device, dtype=self.dtype).unsqueeze(0)

            stereo = (mode == 'stereo')

            ori_size = left_tensor.shape[-2:]
            if inference_size is not None:
                left_tensor = F.interpolate(left_tensor, size=inference_size, mode='bilinear', align_corners=True)
                right_tensor = F.interpolate(right_tensor, size=inference_size, mode='bilinear', align_corners=True)
                padder = None
            else:
                padder = InputPadder(left_tensor.shape, padding_factor=32)
                left_tensor, right_tensor = padder.pad(left_tensor, right_tensor)

            device_type = "GPU" if self.has_cuda else "CPU"
            actual_size = inference_size if inference_size else ori_size
            status_info = f"Device: {device_type} | Resolution: {actual_size[1]}x{actual_size[0]} | Precision: {precision}"

            start_time = time.time()
            with torch.no_grad():
                results_dict = self.run_frame(left_tensor, right_tensor, stereo, low_res_init)
            inference_time = (time.time() - start_time) * 1000  # ms

            field_up = results_dict['field_up'].permute(0, 3, 1, 2).float().contiguous()

            if padder is not None:
                field_up = padder.unpad(field_up)
            elif inference_size is not None:
                field_up = F.interpolate(field_up, size=ori_size, mode='bilinear', align_corners=True)
                field_up[:, 0] = field_up[:, 0] * (ori_size[1] / float(inference_size[1]))
                field_up[:, 1] = field_up[:, 1] * (ori_size[0] / float(inference_size[0]))

            noc_mask = calc_noc_mask(field_up.permute(0, 2, 3, 1), A=8)
            noc_mask = noc_mask[0].detach().cpu().numpy()
            noc_mask = np.where(noc_mask, 255, 128).astype(np.uint8)

            field_up = torch.cat((field_up, torch.zeros_like(field_up[:, :1])), dim=1)
            field_up = field_up.permute(0, 2, 3, 1).contiguous()
            field, field_r = field_up.chunk(2, dim=0)

            if stereo:
                disparity = (-field[..., 0]).clamp(min=0)

                disparity_np = disparity[0].detach().cpu().numpy()
                min_val = disparity_np.min()
                max_val = disparity_np.max()
                if max_val - min_val > 1e-6:
                    disparity_norm = (disparity_np - min_val) / (max_val - min_val)
                else:
                    disparity_norm = np.zeros_like(disparity_np)
                disparity_img = (disparity_norm * 255).astype(np.uint8)

                return disparity_img, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
            else:
                flow = field[0].detach().cpu().numpy()
                flow_rgb = self.flow_to_color(flow)
                return flow_rgb, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info

        except Exception as e:
            device_type = "GPU" if self.has_cuda else "CPU"
            return None, None, f"Error during inference: {str(e)}", f"Device: {device_type} | Error occurred"

    def flow_to_color(self, flow):
        """visualization of flow"""
        u = flow[..., 0]
        v = flow[..., 1]

        rad = np.sqrt(u**2 + v**2)
        rad_max = np.max(rad)
        epsilon = 1e-8

        if rad_max > epsilon:
            u = u / (rad_max + epsilon)
            v = v / (rad_max + epsilon)

        h, w = u.shape
        hsv = np.zeros((h, w, 3), dtype=np.uint8)
        hsv[..., 1] = 255

        mag, ang = cv2.cartToPolar(u, v)
        hsv[..., 0] = ang * 180 / np.pi / 2
        hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)

        flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
        return flow_rgb

demo_model = MatchStereoDemo()

# example images
examples = [
    ["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
    ["examples/staircase_q_left.png", "examples/staircase_q_right.png", "stereo", "tiny"],
    ["examples/frame_0031_clean.png", "examples/frame_0032_clean.png", "flow", "base"],
]

@spaces.GPU
def process_inference(left_img, right_img, mode, variant,
                      low_res_init, inference_size, precision, mat_impl):
    """Gradio function"""
    if left_img is None or right_img is None:
        return None, None, "Please upload both left and right images", "Waiting for input..."

    try:
        result = demo_model.process_images(
            left_img, right_img, mode, variant,
            low_res_init, inference_size, precision, mat_impl
        )
        return result
    except Exception as e:
        return None, None, f"Error during inference: {str(e)}", f"Error: {str(e)}"

def update_variant_choices(mode):
    if mode == "flow":
        return gr.Radio(choices=["base"], value="base")
    else:
        return gr.Radio(choices=["tiny", "small", "base"], value="tiny")

# Gradio UI
with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
    gr.Markdown("# MatchStereo/MatchFlow Demo")
    gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")

    if not demo_model.has_cuda:
        gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")

    with gr.Row():
        with gr.Column():
            left_image = gr.Image(label="Left Image / Frame 1", type="pil")
            right_image = gr.Image(label="Right Image / Frame 2", type="pil")

            with gr.Row():
                mode = gr.Radio(
                    choices=["stereo", "flow"],
                    label="Mode",
                    value="stereo",
                    info="Select stereo for disparity estimation or flow for optical flow"
                )
                variant = gr.Radio(
                    choices=["tiny", "small", "base"],
                    label="Model Variant",
                    value="tiny",
                    info="Model size variant"
                )

            with gr.Row():
                low_res_init = gr.Checkbox(
                    label="Low Resolution Init",
                    value=False,
                    info="Use low-resolution initialization for high-res images (>=2K)"
                )
                inference_size = gr.Dropdown(
                    choices=["Original", "720P", "1080P", "2K", "4K UHD"],
                    label="Inference Size",
                    value="Original",
                    info="Rounded to multiples of 32"
                )

            with gr.Row():
                precision = gr.Radio(
                    choices=["fp32", "fp16"],
                    label="Precision",
                    value="fp32",
                    info="Model precision",
                    interactive=demo_model.has_cuda
                )
                mat_impl = gr.Radio(
                    choices=["cuda", "pytorch"],
                    label="MatchAttention Implementation",
                    value="cuda",
                    info="MatchAttention implementations",
                    interactive=demo_model.has_cuda
                )

            run_btn = gr.Button("Run Inference", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Output Result", interactive=False)
            noc_mask = gr.Image(label="NOC Mask", interactive=False)
            time_output = gr.Textbox(label="Inference Time", interactive=False)
            status = gr.Textbox(label="Status Info", interactive=False, lines=2)

    gr.Markdown("## Examples")
    gr.Examples(
        examples=examples,
        inputs=[left_image, right_image, mode, variant],
        outputs=[output_image, noc_mask, time_output, status],
        fn=process_inference,
        cache_examples=False,
        label="Click any example below to load it"
    )

    run_btn.click(
        fn=process_inference,
        inputs=[left_image, right_image, mode, variant,
                low_res_init, inference_size, precision, mat_impl],
        outputs=[output_image, noc_mask, time_output, status]
    )

    mode.change(
        fn=update_variant_choices,
        inputs=[mode],
        outputs=[variant]
    )

if __name__ == "__main__":
    try:
        import cv2
    except ImportError:
        print("Please install OpenCV for optical flow visualization: pip install opencv-python")

    demo.launch()
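A note on the structure: MatchStereoDemo.process_images does the heavy lifting, while the thin process_inference wrapper carries the @spaces.GPU decorator, which on a ZeroGPU Space attaches a GPU only for the duration of each call (outside Spaces it is a no-op). Run locally, `python gradio_app.py` serves the same UI; the Space itself launches it through the app_file entry in README.md.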
models/__init__.py
ADDED (empty file)
models/attention_blocks.py
ADDED

import torch
import torch.nn as nn

from timm.models.layers import DropPath

from models.convformer import LayerNormWithoutBias
from models.common import ConvGLU
from models.mat_pytorch_impl import compute_bilinear_weights, compute_match_attention, compute_bilinear_softmax, attention_aggregate
from models.match_former_ops import MF_FusedForwardOps
from utils.utils import bilinear_sample_by_offset, init_coords

class MatchAttention(torch.nn.Module):
    r"""MatchAttention: Matching the relative positions
    """
    def __init__(self, args, dim, win_r=[1, 1], num_head=8, head_dim=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, cross=False, noc_embed=False, **kargs):
        super().__init__()

        self.num_head = num_head
        self.cross = cross
        self.noc_embed = noc_embed if not cross else False  # only for self attention

        self.head_dim = dim // num_head if head_dim is None else head_dim
        self.scale = self.head_dim ** -0.5

        self.attention_dim = self.num_head * self.head_dim

        self.win_r = win_r
        self.attn_num = (2*win_r[0]+2)*(2*win_r[1]+2)

        embed_dim = dim + 1 if noc_embed else dim  # '1' for noc_mask
        self.q = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.k = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.v = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        if self.cross:
            self.g = nn.Sequential(nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias), nn.SiLU())
            self.proj = nn.Linear(self.attention_dim + self.num_head*self.attn_num, dim, bias=proj_bias)
        else:
            self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_pytorch = (args.mat_impl == 'pytorch')
        self.mf_fused = MF_FusedForwardOps()

    def clamp_max_offset(self, max_offset, H, W):
        max_offset_x, max_offset_y = max_offset.chunk(2, dim=-1)  # to avoid inplace operation

        # for ONNX support
        min_x = torch.tensor(self.win_r[0], dtype=max_offset.dtype, device=max_offset.device)
        max_x = torch.tensor(W - 1 - self.win_r[0] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        min_y = torch.tensor(self.win_r[1], dtype=max_offset.dtype, device=max_offset.device)
        max_y = torch.tensor(H - 1 - self.win_r[1] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)

        max_offset_x = torch.clamp(max_offset_x, min=min_x, max=max_x)
        max_offset_y = torch.clamp(max_offset_y, min=min_y, max=max_y)

        ## max_offset_x = max_offset_x.clamp(min=self.win_r[0], max=W-1-self.win_r[0]-1e-3)
        ## max_offset_y = max_offset_y.clamp(min=self.win_r[1], max=H-1-self.win_r[1]-1e-3)
        return torch.cat((max_offset_x, max_offset_y), dim=-1).contiguous()

    def forward(self, x, max_offset, noc_mask=None):  # offset: [B, N, h, 2]
        B, H, W, _ = x.shape
        N = H*W
        assert (2*self.win_r[1] + 2 <= H) and (2*self.win_r[0] + 2 <= W)
        x = x.view(B, N, -1).contiguous()

        if self.cross:
            ref_, tgt_ = x.chunk(2, dim=0)  # split along batch dimension
            ref = torch.cat((ref_, tgt_), dim=0)  # order
            tgt = torch.cat((tgt_, ref_), dim=0)  # reverse order
            g = self.g(ref)
        else:  # self-attn
            if self.noc_embed:
                x = torch.cat((x, noc_mask.view(B, N, -1)), dim=-1).contiguous()
            ref, tgt = x, x
        q, k, v = self.q(ref), self.k(tgt), self.v(tgt)

        ## non-parameter modules
        max_offset = self.clamp_max_offset(max_offset, H, W)

        if self.use_pytorch:
            m_id = torch.floor(max_offset).to(torch.int32)  # [B, N, h, 2]
            bilinear_weight = compute_bilinear_weights(max_offset)

            attn, indices_gather = compute_match_attention(q.view(B, N, self.num_head, -1), k.view(B, N, self.num_head, -1), m_id, self.win_r, H, W)
            attn = attn * self.scale

            attn = compute_bilinear_softmax(attn, bilinear_weight, self.win_r)
            attn = self.attn_drop(attn)

            x = attention_aggregate(v.view(B, N, self.num_head, -1), attn, indices_gather, self.win_r)
        else:
            x, attn = self.mf_fused(max_offset, q, k, v, H, W, self.win_r, self.attn_num, attn_type='l1_norm', scale=self.scale)

        if self.cross:
            x = g * x  # gate
            attn = attn.view(B, N, -1).contiguous()
            x = torch.cat((x, attn), dim=-1).contiguous()
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.view(B, H, W, -1).contiguous()


class MatchAttentionLayer(nn.Module):
    r"""MatchAttention layer with interleaved self-MatchAttention, cross-MatchAttention, and ConvGLU
    """

    def __init__(self, args, dim, win_r,
                 num_head=8, head_dim=32, mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=nn.LayerNorm, drop=0., drop_path=0.):
        super().__init__()
        self.num_head = num_head
        self.field_dim = field_dim

        self.match_attention_self = MatchAttention(args, dim + self.field_dim + self.num_head*2, [win_r, win_r], num_head=num_head, head_dim=head_dim, noc_embed=True)
        self.norm0 = norm_layer(dim + self.field_dim + self.num_head*2)

        self.match_attention_cross = MatchAttention(args, dim + self.field_dim, [win_r, win_r], num_head=num_head, head_dim=head_dim, cross=True)
        self.norm1 = norm_layer(dim + self.field_dim)

        self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.norm2 = norm_layer(dim)

        self.field_scale = nn.Parameter(0.1*torch.ones(1, 1, 1, 2))

        self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def consistency_mask(self, field, A=2):
        offset = field + init_coords(field)  # [B, H, W, 2]
        field_ref_, field_tgt_ = field.chunk(2, dim=0)
        field_ref = torch.cat((field_ref_, field_tgt_), dim=0)  # order
        field_tgt = torch.cat((field_tgt_, field_ref_), dim=0)  # reverse order
        field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(), offset).permute(0, 2, 3, 1).contiguous()
        field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1, keepdim=True)  # ref and tgt flows have different signs
        noc_mask = (field_diff < A).to(field_diff.dtype)
        return noc_mask

    def forward(self, x, self_rpos, field, stereo=True):  # self_rpos [B, H, W, h*2], field [B, H, W, 2]

        field_out = {}
        B, H, W, C = x.shape

        noc_mask = self.consistency_mask(field.detach())

        x = torch.cat((x, field*self.field_scale.to(field.dtype), self_rpos), dim=-1).contiguous()

        coords_0 = init_coords(field).repeat(1, 1, 1, self.num_head)
        self_offset = self_rpos + coords_0
        self_offset = self_offset.view(B, H*W, self.num_head, 2).contiguous()

        x = x + self.drop_path0(self.match_attention_self(self.norm0(x), self_offset, noc_mask))

        self_rpos = x[..., -(self.num_head*2):].contiguous()  # [B, H, W, h*2]
        x = x[..., :-(self.num_head*2)].contiguous()

        if stereo: x[..., -1] = 0
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['self'] = field.clone()

        offset = field.repeat(1, 1, 1, self.num_head).contiguous() + coords_0  # [B, H, W, h*2]
        offset = offset.view(B, H*W, self.num_head, 2).contiguous()

        x = x + self.drop_path1(self.match_attention_cross(self.norm1(x), offset))

        if stereo: x[..., -1] = 0
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['cross'] = field.clone()

        x = x[..., :-self.field_dim].contiguous()  # No field feature in MLP

        x = x + self.drop_path2(self.mlp(self.norm2(x)))

        return x, self_rpos, field, field_out


class MatchAttentionBlock(nn.Module):
    r"""MatchAttention block with multiple match-attention layers
    """

    def __init__(self, args, dim, win_r=2,
                 num_layer=6, num_head=8, head_dim=32,
                 mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=LayerNormWithoutBias,
                 drop=0., dp_rates=[0.]):

        super().__init__()
        self.num_head = num_head

        self.layers = nn.ModuleList()
        for i in range(num_layer):
            layer = MatchAttentionLayer(args, dim, win_r=win_r, num_head=num_head, head_dim=head_dim,
                                        mlp=mlp, mlp_ratio=mlp_ratio, field_dim=field_dim,
                                        norm_layer=norm_layer, drop=drop, drop_path=dp_rates[i])
            self.layers.append(layer)

    def forward(self, x, self_rpos, field, stereo=True):
        fields = []
        B, H, W, C = x.shape
        self_rpos = self_rpos.repeat(1, 1, 1, self.num_head)  # [B, H, W, 2] -> [B, H, W, h*2]

        for layer in self.layers:

            x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo)
            fields.append(field_out)

        self_rpos = self_rpos.view(B, H, W, self.num_head, 2).mean(dim=-2, keepdim=False)

        return x, self_rpos, field, fields
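For orientation, here is a minimal sketch of driving one MatchAttentionLayer end to end, assuming the repository modules (including models.mat_pytorch_impl) are importable; the batch stacks the reference and target views, as the cross-attention's chunk(2, dim=0) expects, and all sizes are illustrative:

import argparse
import torch
from models.attention_blocks import MatchAttentionLayer

args = argparse.Namespace(mat_impl='pytorch')  # mirrors how gradio_app.py builds args

B, H, W, dim, num_head = 2, 16, 16, 64, 8       # B=2: [ref, tgt] stacked along batch
layer = MatchAttentionLayer(args, dim=dim, win_r=2, num_head=num_head, head_dim=32).eval()

x = torch.randn(B, H, W, dim)                   # backbone features, channels-last
self_rpos = torch.zeros(B, H, W, num_head * 2)  # per-head relative positions
field = torch.zeros(B, H, W, 2)                 # current disparity/flow field

with torch.no_grad():
    x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo=True)
print(x.shape, field.shape)  # [2, 16, 16, 64] and [2, 16, 16, 2]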
models/common.py
ADDED

import torch
import torch.nn as nn
import torch.nn.functional as F

class UpConv(nn.Module):
    r"""Upsample using transposed conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels*2, out_channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x1, x2, use_up=True):
        x1 = x1.permute(0, 3, 1, 2).contiguous()
        x2 = x2.permute(0, 3, 1, 2).contiguous()
        if use_up:
            x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        out = self.conv(x)
        return out.permute(0, 2, 3, 1).contiguous()  # [B, H, W, C]

class ConvGLU(nn.Module):
    '''
    Convolutional GLU, referenced from TransNeXt
    '''
    def __init__(self, dim, mlp_ratio=2, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        self.fc1 = nn.Linear(in_features, hidden_features * 2)
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, bias=True, groups=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):  # [B, H, W, C]
        x, v = self.fc1(x).chunk(2, dim=-1)
        x = self.act(self.dwconv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()) * v
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
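Both modules here operate on channels-last tensors; a minimal shape check for ConvGLU (sizes hypothetical):

import torch
from models.common import ConvGLU

glu = ConvGLU(dim=64, mlp_ratio=2)
x = torch.randn(1, 32, 32, 64)  # [B, H, W, C]
print(glu(x).shape)             # torch.Size([1, 32, 32, 64])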
models/compile.sh
ADDED

#!/bin/bash
rm -rf build/ dist/ match_attention.egg-info/ __pycache__
python setup.py clean
pip install .
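Presumably this script is meant to be run from inside models/, since `pip install .` builds whatever the local setup.py defines; the match_attention.egg-info/ it cleans up suggests that is the package name under which the CUDA sources added in models/src/ get compiled.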
models/convformer.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 7 |
+
from timm.models.registry import register_model
|
| 8 |
+
from timm.models.layers.helpers import to_2tuple
|
| 9 |
+
class LayerNormGeneral(nn.Module):
|
| 10 |
+
r""" General LayerNorm for different situations.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
affine_shape (int, list or tuple): The shape of affine weight and bias.
|
| 14 |
+
Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm,
|
| 15 |
+
the affine_shape is the same as normalized_dim by default.
|
| 16 |
+
To adapt to different situations, we offer this argument here.
|
| 17 |
+
normalized_dim (tuple or list): Which dims to compute mean and variance.
|
| 18 |
+
scale (bool): Flag indicates whether to use scale or not.
|
| 19 |
+
bias (bool): Flag indicates whether to use scale or not.
|
| 20 |
+
|
| 21 |
+
We give several examples to show how to specify the arguments.
|
| 22 |
+
|
| 23 |
+
LayerNorm (https://arxiv.org/abs/1607.06450):
|
| 24 |
+
For input shape of (B, *, C) like (B, N, C) or (B, H, W, C),
|
| 25 |
+
affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True;
|
| 26 |
+
For input shape of (B, C, H, W),
|
| 27 |
+
affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True.
|
| 28 |
+
|
| 29 |
+
Modified LayerNorm (https://arxiv.org/abs/2111.11418)
|
| 30 |
+
that is idental to partial(torch.nn.GroupNorm, num_groups=1):
|
| 31 |
+
For input shape of (B, N, C),
|
| 32 |
+
affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True;
|
| 33 |
+
For input shape of (B, H, W, C),
|
| 34 |
+
affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True;
|
| 35 |
+
For input shape of (B, C, H, W),
|
| 36 |
+
affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True.
|
| 37 |
+
|
| 38 |
+
For the several metaformer baslines,
|
| 39 |
+
IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False);
|
| 40 |
+
ConvFormer and CAFormer utilizes LayerNorm without bias (bias=False).
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
|
| 44 |
+
bias=False, eps=1e-6):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.normalized_dim = normalized_dim
|
| 47 |
+
self.use_scale = scale
|
| 48 |
+
self.use_bias = bias
|
| 49 |
+
self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
|
| 50 |
+
self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
|
| 51 |
+
self.eps = eps
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
c = x - x.mean(self.normalized_dim, keepdim=True)
|
| 55 |
+
s = c.pow(2).mean(self.normalized_dim, keepdim=True)
|
| 56 |
+
x = c / torch.sqrt(s + self.eps)
|
| 57 |
+
if self.use_scale:
|
| 58 |
+
x = x * self.weight
|
| 59 |
+
if self.use_bias:
|
| 60 |
+
x = x + self.bias
|
| 61 |
+
return x
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def stem(in_chs, out_chs, act_layer=nn.GELU):
|
| 66 |
+
return nn.Sequential(
|
| 67 |
+
nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
|
| 68 |
+
## nn.BatchNorm2d(out_chs // 2),
|
| 69 |
+
nn.InstanceNorm2d(out_chs // 2),
|
| 70 |
+
act_layer(),
|
| 71 |
+
nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
|
| 72 |
+
## nn.BatchNorm2d(out_chs),
|
| 73 |
+
nn.InstanceNorm2d(out_chs),
|
| 74 |
+
act_layer(),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
class Downsampling(nn.Module):
|
| 78 |
+
"""
|
| 79 |
+
Downsampling implemented by a layer of convolution.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, in_channels, out_channels,
|
| 83 |
+
kernel_size=3, stride=2, padding=1,
|
| 84 |
+
pre_norm=LayerNormGeneral, post_norm=None, pre_permute=True):
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
|
| 87 |
+
self.pre_permute = pre_permute
|
| 88 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
|
| 89 |
+
stride=stride, padding=padding)
|
| 90 |
+
self.post_norm = post_norm(
|
| 91 |
+
out_channels) if post_norm else nn.Identity()
|
| 92 |
+
|
| 93 |
+
def forward(self, x):
|
| 94 |
+
x = self.pre_norm(x)
|
| 95 |
+
if self.pre_permute:
|
| 96 |
+
x = x.permute(0, 3, 1, 2).contiguous() # if take [B, H, W, C] as input, permute it to [B, C, H, W]
|
| 97 |
+
x = self.conv(x)
|
| 98 |
+
x = x.permute(0, 2, 3, 1).contiguous() # [B, C, H, W] -> [B, H, W, C]
|
| 99 |
+
x = self.post_norm(x)
|
| 100 |
+
return x
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Scale(nn.Module):
|
| 104 |
+
"""
|
| 105 |
+
Scale vector by element multiplications.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, dim, init_value=1.0, trainable=True):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.scale = nn.Parameter(
|
| 111 |
+
init_value * torch.ones(dim), requires_grad=trainable)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
return x * self.scale
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class LayerNormWithoutBias(nn.Module):
|
| 118 |
+
"""
|
| 119 |
+
Equal to partial(LayerNormGeneral, bias=False) but faster,
|
| 120 |
+
because it directly utilizes otpimized F.layer_norm
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(self, normalized_shape, eps=1e-5, **kwargs):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.eps = eps
|
| 126 |
+
self.bias = None
|
| 127 |
+
if isinstance(normalized_shape, int):
|
| 128 |
+
normalized_shape = (normalized_shape,)
|
| 129 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 130 |
+
self.normalized_shape = normalized_shape
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class SepConv(nn.Module):
|
| 137 |
+
r"""
|
| 138 |
+
Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(self, dim, expansion_ratio=2,
|
| 142 |
+
act1_layer=nn.GELU, act2_layer=nn.Identity,
|
| 143 |
+
bias=False, kernel_size=3, padding=1,
|
| 144 |
+
**kwargs, ):
|
| 145 |
+
super().__init__()
|
| 146 |
+
med_channels = int(expansion_ratio * dim)
|
| 147 |
+
self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
|
| 148 |
+
self.act1 = act1_layer()
|
| 149 |
+
self.dwconv = nn.Conv2d(
|
| 150 |
+
med_channels, med_channels, kernel_size=kernel_size,
|
| 151 |
+
padding=padding, groups=med_channels, bias=bias) # depthwise conv
|
| 152 |
+
self.act2 = act2_layer()
|
| 153 |
+
self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
x = self.pwconv1(x)
|
| 157 |
+
x = self.act1(x)
|
| 158 |
+
x = x.permute(0, 3, 1, 2)
|
| 159 |
+
x = self.dwconv(x)
|
| 160 |
+
x = x.permute(0, 2, 3, 1)
|
| 161 |
+
x = self.act2(x)
|
| 162 |
+
x = self.pwconv2(x)
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
class Mlp(nn.Module):
|
| 166 |
+
""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
|
| 167 |
+
Mostly copied from timm.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=nn.GELU, drop=0., bias=False, **kwargs):
|
| 171 |
+
super().__init__()
|
| 172 |
+
in_features = dim
|
| 173 |
+
out_features = out_features or in_features
|
| 174 |
+
hidden_features = int(mlp_ratio * in_features)
|
| 175 |
+
drop_probs = to_2tuple(drop)
|
| 176 |
+
|
| 177 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 178 |
+
self.act = act_layer()
|
| 179 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 180 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 181 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 182 |
+
|
| 183 |
+
def forward(self, x):
|
| 184 |
+
x = self.fc1(x)
|
| 185 |
+
x = self.act(x)
|
| 186 |
+
x = self.drop1(x)
|
| 187 |
+
x = self.fc2(x)
|
| 188 |
+
x = self.drop2(x)
|
| 189 |
+
return x
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class MetaFormerBlock(nn.Module):
|
| 193 |
+
"""
|
| 194 |
+
Implementation of one MetaFormer block.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(self, dim,
|
| 198 |
+
token_mixer=nn.Identity, mlp=Mlp, mlp_ratio=4,
|
| 199 |
+
norm_layer=nn.LayerNorm, drop=0., drop_path=0.,
|
| 200 |
+
layer_scale_init_value=None, res_scale_init_value=None
|
| 201 |
+
):
|
| 202 |
+
|
| 203 |
+
super().__init__()
|
| 204 |
+
|
| 205 |
+
self.token_mixer = token_mixer(dim, drop=drop)
|
| 206 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 207 |
+
self.norm1 = norm_layer(dim)
|
| 208 |
+
self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \
|
| 209 |
+
if layer_scale_init_value else nn.Identity()
|
| 210 |
+
self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \
|
| 211 |
+
if res_scale_init_value else nn.Identity()
|
| 212 |
+
|
| 213 |
+
self.norm2 = norm_layer(dim)
|
| 214 |
+
self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
|
| 215 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 216 |
+
self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \
|
| 217 |
+
if layer_scale_init_value else nn.Identity()
|
| 218 |
+
self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \
|
| 219 |
+
if res_scale_init_value else nn.Identity()
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
x = x + self.drop_path1(self.token_mixer(self.norm1(x)))
|
| 223 |
+
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
| 224 |
+
return x
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class MetaFormer(nn.Module):
|
| 228 |
+
r""" MetaFormer
|
| 229 |
+
A PyTorch impl of : `MetaFormer Baselines for Vision` -
|
| 230 |
+
https://arxiv.org/abs/2210.13452
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 234 |
+
num_classes (int): Number of classes for classification head. Default: 1000.
|
| 235 |
+
depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
|
| 236 |
+
dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512].
|
| 237 |
+
downsample_layers: (list or tuple): Downsampling layers before each stage.
|
| 238 |
+
token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
|
| 239 |
+
mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
|
| 240 |
+
norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False).
|
| 241 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
| 242 |
+
layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None.
|
| 243 |
+
None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239.
|
| 244 |
+
res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0].
|
| 245 |
+
None means not use the layer scale. From: https://arxiv.org/abs/2110.09456.
|
| 246 |
+
head_fn: classification head. Default: nn.Linear.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
def __init__(self, in_chans=3, num_classes=1000,
|
| 250 |
+
depths=[2, 2, 6, 2],
|
| 251 |
+
dims=[64, 128, 320, 512],
|
| 252 |
+
downsample_layers=[stem] + [Downsampling]*3,
|
| 253 |
+
token_mixers=nn.Identity,
|
| 254 |
+
mlps=Mlp, mlp_ratio=4,
|
| 255 |
+
norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
|
| 256 |
+
drop_path_rate=0.,
|
| 257 |
+
layer_scale_init_values=None,
|
| 258 |
+
res_scale_init_values=[None, None, 1.0, 1.0],
|
| 259 |
+
head_fn=nn.Linear,
|
| 260 |
+
**kwargs,
|
| 261 |
+
):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.num_classes = num_classes
|
| 264 |
+
|
| 265 |
+
if not isinstance(depths, (list, tuple)):
|
| 266 |
+
depths = [depths] # it means the model has only one stage
|
| 267 |
+
if not isinstance(dims, (list, tuple)):
|
| 268 |
+
dims = [dims]
|
| 269 |
+
|
| 270 |
+
self.dims = dims
|
| 271 |
+
self.depths = depths
|
| 272 |
+
|
| 273 |
+
num_stage = len(depths)
|
| 274 |
+
self.num_stage = num_stage
|
| 275 |
+
|
| 276 |
+
down_dims = [in_chans] + dims
|
| 277 |
+
self.downsample_layers = nn.ModuleList([downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)])
|
| 278 |
+
|
| 279 |
+
if not isinstance(token_mixers, (list, tuple)):
|
| 280 |
+
token_mixers = [token_mixers] * num_stage
|
| 281 |
+
self.token_mixers = token_mixers
|
| 282 |
+
|
| 283 |
+
if not isinstance(mlps, (list, tuple)):
|
| 284 |
+
mlps = [mlps] * num_stage
|
| 285 |
+
|
| 286 |
+
if not isinstance(norm_layers, (list, tuple)):
|
| 287 |
+
norm_layers = [norm_layers] * num_stage
|
| 288 |
+
|
| 289 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 290 |
+
|
| 291 |
+
        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()  # each stage consists of multiple MetaFormer blocks
        cur = 0
        for i in range(num_stage):
            stage = nn.ModuleList(
                [MetaFormerBlock(dim=dims[i], token_mixer=token_mixers[i],
                                 mlp=mlps[i], mlp_ratio=mlp_ratio, norm_layer=norm_layers[i],
                                 drop_path=dp_rates[cur + j],
                                 layer_scale_init_value=layer_scale_init_values[i],
                                 res_scale_init_value=res_scale_init_values[i],
                                 ) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = head_fn(dims[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        outs = []
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            if i == 0:
                x = x.permute(0, 2, 3, 1).contiguous()  # [B, C, H, W] -> [B, H, W, C]
            for j in range(self.depths[i]):
                x = self.stages[i][j](x)
            outs.append(x)  # [B, H, W, C]
        return outs


def convformer(variant='tiny'):
    if variant == 'tiny':
        model = convformer_t()
    elif variant == 'small':
        model = convformer_s()
    elif variant == 'base':
        model = convformer_b()
    elif variant == 'large':
        model = convformer_l()
    else:
        raise NotImplementedError
    return model


@register_model
def convformer_t(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[32, 64, 128, 160],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


@register_model
def convformer_s(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 160, 320],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


@register_model
def convformer_b(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[128, 256, 320, 512],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


@register_model
def convformer_l(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[256, 384, 512, 768],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model
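A minimal usage sketch of the factory above: the 3-channel input and the per-stage strides (4, 2, 2, 2 in the usual MetaFormer layout) are assumptions here, since the constructor defaults sit in the earlier, unshown part of the file.

import torch
from models.convformer import convformer

encoder = convformer('tiny')         # depths [2, 2, 6, 2], dims [32, 64, 128, 160]
x = torch.randn(1, 3, 256, 512)      # assumed 3-channel input
feats = encoder(x)                   # four channels-last feature maps, one per stage
for f in feats:
    print(f.shape)                   # e.g. [1, 64, 128, 32] down to [1, 8, 16, 160] under 4,2,2,2 strides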
models/cost_volume.py
ADDED
@@ -0,0 +1,179 @@
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.convformer import LayerNormWithoutBias
from utils.utils import init_coords


class GlobalCorrelation(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.norm = LayerNormWithoutBias(dim)
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.scale = dim**-0.5

    def forward(self, x, stereo=True):
        x = self.norm(x)
        ref, tgt = x.chunk(2, dim=0)
        ref, tgt = self.q(ref), self.k(tgt)
        # global correlation along the horizontal direction
        B, H, W, C = ref.shape

        if stereo:
            correlation = torch.matmul(ref, tgt.transpose(-2, -1)) * self.scale  # [B, H, W, W]

            # mask subsequent positions to make disparity positive
            mask = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=1)  # [W, W]
            valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1)  # [B, H, W, W]

            mask_ = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=0)  # mask for input order [right, left]
            valid_mask_ = (mask_ != 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1)  # upper right
            valid_mask = torch.cat((valid_mask, valid_mask_), dim=0)  # [B*2, H, W, W]
            correlation = torch.cat((correlation, correlation.permute(0, 1, 3, 2)), dim=0)  # [B*2, H, W, W]
            B = B * 2

            correlation[~valid_mask] = -1e9 if correlation.dtype == torch.float32 else -1e4

            # build volume from correlation
            D = W  # all-pair correlation
            volume = correlation.new_zeros([B, D, H, W])
            for d in range(D):  # most time-consuming
                volume[:B//2, d, :, d:] = correlation[:B//2, :, range(d, W), range(W-d)]
                volume[B//2:, d, :, :(W-d)] = correlation[B//2:, :, range(W-d), range(d, W)]

            volume = F.softmax(volume, dim=1).to(volume.dtype)

            volume_clone = volume.clone()
            for d in range(D):  # fill out-of-view positions; second most time-consuming
                volume_clone[:B//2, d, :, :d] = volume[:B//2, d, :, d:d+1]  # left
                volume_clone[B//2:, d, :, W-1-d:] = volume[B//2:, d, :, W-1-d:(W-d)]  # right

            flow = local_disparity_estimator(volume_clone)
            return flow, volume_clone
        else:
            init_grid = init_coords(ref)  # [B, H, W, 2]
            ref = ref.view(B, -1, C)  # [B, H*W, C]
            tgt = tgt.view(B, -1, C)  # [B, H*W, C]

            correlation = torch.matmul(ref, tgt.transpose(-2, -1)) * self.scale  # [B, H*W, H*W]
            correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0)  # [2*B, H*W, H*W]
            init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, H, W, 2]
            B = B * 2

            prob = F.softmax(correlation, dim=-1).to(correlation.dtype)  # [B, H*W, H*W]

            flow = local_flow_estimator(prob, init_grid)

            return flow, prob.view(B, H, W, H*W)

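# A worked example of the masking in GlobalCorrelation above: for W = 4,
#   torch.triu(torch.ones(4, 4), diagonal=1) == 0
# keeps the lower triangle, so a left-view pixel at column x only correlates
# with right-view columns x' <= x, i.e. disparity d = x - x' >= 0; the
# diagonal=0 mask keeps the upper triangle (diagonal included) for the
# transposed [right, left] correlation.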
def local_flow_estimator(prob, init_grid, k=5):
    """
    Flow estimator using a weighted sum within a local window centered at the max-probability match
    Args:
        prob: normalized correlation volume [B, H*W, H*W]
        init_grid: initial coordinate grid [B, H, W, 2]
        k: local window size (odd number)
    Returns:
        flow: optical flow field [B, H, W, 2]
    """
    B, H, W, _ = init_grid.shape
    r = k // 2
    device = prob.device

    prob_blur = F.avg_pool2d(prob, kernel_size=k, stride=1, padding=r).view(B, H*W, H*W)

    max_prob, max_idx = torch.max(prob_blur, dim=-1)  # [B, H*W]
    max_idx = max_idx.unsqueeze(-1)  # [B, H*W, 1]
    target_coords = init_grid  # [B, H, W, 2]
    max_y = max_idx // W  # [B, H*W, 1]
    max_x = max_idx % W  # [B, H*W, 1]
    max_y = torch.clamp(max_y, r, H-1-r)
    max_x = torch.clamp(max_x, r, W-1-r)

    yy, xx = torch.meshgrid(torch.arange(-r, r+1, device=device), torch.arange(-r, r+1, device=device), indexing='ij')
    offsets_y = yy.reshape(1, 1, k*k, 1)  # [1, 1, k*k, 1]
    offsets_x = xx.reshape(1, 1, k*k, 1)  # [1, 1, k*k, 1]
    sample_y = max_y.unsqueeze(2) + offsets_y  # [B, H*W, k*k, 1]
    sample_x = max_x.unsqueeze(2) + offsets_x  # [B, H*W, k*k, 1]
    sample_y = sample_y.long().squeeze(-1)  # [B, H*W, k*k]
    sample_x = sample_x.long().squeeze(-1)  # [B, H*W, k*k]

    batch_idx = torch.arange(B, device=device).view(B, 1, 1).expand(-1, H*W, k*k)
    window_coords = target_coords[batch_idx, sample_y, sample_x]  # [B, H*W, k*k, 2]

    window_indices = sample_y * W + sample_x  # [B, H*W, k*k]
    window_probs = torch.gather(prob, dim=-1, index=window_indices)  # [B, H*W, k*k]

    mean_prob = 1.0 / (H * W)
    invalid_mask = window_probs < mean_prob
    window_probs[invalid_mask] = 0

    window_probs_sum = window_probs.sum(dim=-1, keepdim=True).to(window_probs.dtype)
    window_probs_sum = torch.clamp(window_probs_sum, min=torch.finfo(window_probs_sum.dtype).tiny)
    normalized_probs = window_probs / window_probs_sum  # [B, H*W, k*k]
    normalized_probs = normalized_probs.unsqueeze(-1)  # [B, H*W, k*k, 1]
    correspondence = torch.sum(normalized_probs * window_coords, dim=2).to(normalized_probs.dtype)  # [B, H*W, 2]
    correspondence = correspondence.view(B, H, W, 2)  # [B, H, W, 2]
    flow = correspondence - init_grid

    return flow


def local_disparity_estimator(cv, k=5):
    """
    Disparity estimator using a weighted sum within a local window centered at the max-probability match
    Args:
        cv: cost volume [B, D, H, W]
        k: local window size (odd number)
    Returns:
        flow: [B, H, W, 2]
    """
    B, D, H, W = cv.shape
    r = k // 2
    device = cv.device

    cv_blur = F.avg_pool1d(cv.permute(0, 2, 3, 1).view(B, -1, D), kernel_size=k, stride=1, padding=r).view(B, H, W, D).permute(0, 3, 1, 2)

    # find the max index in the blurred cost volume
    max_cv, max_idx = torch.max(cv_blur, dim=1)  # max_idx: [B, H, W]
    max_idx = max_idx.unsqueeze(1)  # [B, 1, H, W]
    max_idx = torch.clamp(max_idx, r, D-1-r)  # [B, 1, H, W]

    offsets = torch.arange(-r, r+1, device=device).view(1, k, 1, 1)  # [1, k, 1, 1]

    sample_idx = max_idx + offsets  # [B, k, H, W]
    sample_idx = torch.clamp(sample_idx, 0, D-1)

    batch_idx = torch.arange(B, device=device).view(B, 1, 1, 1).expand(-1, k, H, W)
    h_idx = torch.arange(H, device=device).view(1, 1, H, 1).expand(B, k, H, W)
    w_idx = torch.arange(W, device=device).view(1, 1, 1, W).expand(B, k, H, W)

    window_probs = cv[batch_idx, sample_idx, h_idx, w_idx]  # [B, k, H, W]

    mean_prob = 1.0 / D
    invalid_mask = window_probs < mean_prob
    window_probs[invalid_mask] = 0

    # normalize within the local window
    window_probs_sum = window_probs.sum(dim=1, keepdim=True).to(window_probs.dtype)  # [B, 1, H, W]
    window_probs_sum = torch.clamp(window_probs_sum, min=torch.finfo(window_probs_sum.dtype).tiny)
    normalized_probs = window_probs / window_probs_sum  # [B, k, H, W]

    window_disp = sample_idx.to(normalized_probs.dtype)  # [B, k, H, W]

    disp = torch.sum(normalized_probs * window_disp, dim=1).to(normalized_probs.dtype).unsqueeze(-1)  # [B, H, W, 1]

    return disp_to_flow(disp, B)


def disp_to_flow(disp, B):
    ## disp[:B//2, ...] = -disp[:B//2, ...]  # negative left flow

    ## torch.where instead of in-place masking, for ONNX export support
    batch_indices = torch.arange(B, device=disp.device)
    mask = batch_indices < (B // 2)

    disp = torch.where(mask.view(B, 1, 1, 1), -disp, disp)

    flow = torch.cat((disp, torch.zeros_like(disp)), dim=-1).contiguous()  # [B, H, W, 2]
    return flow
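A small self-contained sketch of how local_disparity_estimator and disp_to_flow behave; the cost volume is synthetic and the sizes are chosen only for illustration.

import torch
from models.cost_volume import local_disparity_estimator

B2, D, H, W = 2, 16, 4, 4                    # batch already doubled: [left, right]
cv = torch.zeros(B2, D, H, W)
cv[:, 5] = 1.0                               # put the correlation peak at disparity 5
flow = local_disparity_estimator(torch.softmax(cv, dim=1))
print(flow.shape)                            # torch.Size([2, 4, 4, 2])
print(flow[0, 0, 0], flow[1, 0, 0])          # (-5, 0) for the left view, (5, 0) for the right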
models/mat_pytorch_impl.py
ADDED
@@ -0,0 +1,178 @@
import torch


def compute_bilinear_weights(grid):
    """
    Compute bilinear weights for BilinearSoftmax
    Args:
        grid: [..., 2], (x, y)
    Returns:
        weights: [..., 4], [nw, ne, sw, se]
    """
    x = grid[..., 0]
    y = grid[..., 1]

    x0 = torch.floor(x)
    y0 = torch.floor(y)

    dx = x - x0
    dy = y - y0

    nw = (1 - dx) * (1 - dy)
    ne = dx * (1 - dy)
    sw = (1 - dx) * dy
    se = dx * dy

    weights = torch.stack([nw, ne, sw, se], dim=-1)

    return weights

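# Worked example for compute_bilinear_weights: grid = (2.25, 3.5) gives
# dx = 0.25, dy = 0.5, hence
#   nw = 0.75*0.5 = 0.375, ne = 0.25*0.5 = 0.125,
#   sw = 0.75*0.5 = 0.375, se = 0.25*0.5 = 0.125.
# The four weights always sum to 1, so the bilinear softmax below remains a
# convex combination of the four sub-window softmaxes.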
def compute_match_attention(q, k, m_id, win_r, H, W):
    """
    Args:
        q: [B, N, h, C]      # Query tensor
        k: [B, N, h, C]      # Key tensor
        m_id: [B, N, h, 2]   # Sampling centers, last dim is (x, y)
        win_r: [int, int]    # Sampling window radii
        H: int               # Height
        W: int               # Width

    Returns:
        output: [B, N, h, M] where M = (2*win_r[0]+2)*(2*win_r[1]+2)
        indices_gather: [B, N, h, M, C] gather indices, reused by attention_aggregate
    """
    B, N, h, C = q.shape
    M = (2*win_r[0] + 2)*(2*win_r[1] + 2)

    dx = torch.arange(-win_r[0], win_r[0] + 2, device=q.device, dtype=torch.long)
    dy = torch.arange(-win_r[1], win_r[1] + 2, device=q.device, dtype=torch.long)
    dy, dx = torch.meshgrid(dy, dx, indexing='ij')
    offsets = torch.stack((dx, dy), dim=-1).reshape(M, 2)  # [M, 2]

    centers = m_id.unsqueeze(3)  # [B, N, h, 1, 2]
    offsets = offsets.view(1, 1, 1, M, 2)  # [1, 1, 1, M, 2]
    coords = centers + offsets  # [B, N, h, M, 2]

    x_coords = coords[..., 0]  # [B, N, h, M]
    y_coords = coords[..., 1]  # [B, N, h, M]

    # Clamp coordinates to the valid range
    x_coords = x_coords.clamp(0, W-1)
    y_coords = y_coords.clamp(0, H-1)

    indices = y_coords * W + x_coords  # [B, N, h, M]

    # [B, N, h, C] -> [B, N, h, M, C]
    k_expanded = k.unsqueeze(3).expand(-1, -1, -1, M, -1)

    # [B, N, h, M] -> [B, N, h, M, C]
    indices_gather = indices.unsqueeze(-1).expand(-1, -1, -1, -1, C)

    # [B, N, h, M, C]
    k_sampled = torch.gather(k_expanded, dim=1, index=indices_gather)

    # [B, N, h, M, C] -> [B, N, h, M]
    # negative L1 norm
    output = -torch.abs(q.unsqueeze(3) - k_sampled).sum(dim=-1)

    return output, indices_gather


def attn_scatter(attn, win_r):
    """
    Scatter the attention to four sub-windows

    Args:
        attn: [B, N, h, M], M = (2*win_r[0]+2) * (2*win_r[1]+2)
        win_r: window radii

    Returns:
        attn_sub: [B, N, h, 4, M_sub] attention for the four sub-windows
    """
    B, N, h, M = attn.shape
    M_sub = (2*win_r[0] + 1)*(2*win_r[1] + 1)

    # [B, N, h, H_win, W_win]
    attn_2d = attn.view(B, N, h, 2*win_r[0] + 2, 2*win_r[1] + 2)

    # nw [0, 0] offset
    win_nw = attn_2d[..., :2*win_r[0]+1, :2*win_r[1]+1]
    # ne [1, 0] offset
    win_ne = attn_2d[..., :2*win_r[0]+1, 1:2*win_r[1]+2]
    # sw [0, 1] offset
    win_sw = attn_2d[..., 1:2*win_r[0]+2, :2*win_r[1]+1]
    # se [1, 1] offset
    win_se = attn_2d[..., 1:2*win_r[0]+2, 1:2*win_r[1]+2]

    win_nw = win_nw.reshape(B, N, h, M_sub)
    win_ne = win_ne.reshape(B, N, h, M_sub)
    win_sw = win_sw.reshape(B, N, h, M_sub)
    win_se = win_se.reshape(B, N, h, M_sub)

    attn_sub = torch.stack([win_nw, win_ne, win_sw, win_se], dim=3)

    return attn_sub


def attn_gather(attn_sub, win_r):
    """
    Gather the four attn_sub windows back into attn

    Args:
        attn_sub: [B, N, h, 4, M_sub]
        win_r: window radii

    Returns:
        merged_attn: [B, N, h, M]
    """
    B, N, h, _, M_sub = attn_sub.shape

    merged = torch.zeros(B, N, h, 2*win_r[0] + 2, 2*win_r[1] + 2, device=attn_sub.device, dtype=attn_sub.dtype)

    # nw [0, 0] offset
    win_nw = attn_sub[:, :, :, 0, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
    merged[..., :2*win_r[0]+1, :2*win_r[1]+1] += win_nw

    # ne [1, 0] offset
    win_ne = attn_sub[:, :, :, 1, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
    merged[..., :2*win_r[0]+1, 1:2*win_r[1]+2] += win_ne

    # sw [0, 1] offset
    win_sw = attn_sub[:, :, :, 2, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
    merged[..., 1:2*win_r[0]+2, :2*win_r[1]+1] += win_sw

    # se [1, 1] offset
    win_se = attn_sub[:, :, :, 3, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
    merged[..., 1:2*win_r[0]+2, 1:2*win_r[1]+2] += win_se

    merged_attn = merged.view(B, N, h, -1)

    return merged_attn


def compute_bilinear_softmax(attn, bilinear_weight, win_r):
    """
    Bilinear Softmax: attention sampled at a continuous position

    Args:
        attn: [B, N, h, M] attention at discrete positions
        bilinear_weight: [B, N, h, 4] bilinear weights from compute_bilinear_weights
        win_r: window radii

    Returns:
        output: [B, N, h, M] effective attention at the continuous position
    """
    attn_sub = attn_scatter(attn, win_r)  # [B, N, h, 4, M_sub]

    attn_weighted = bilinear_weight.unsqueeze(-1)*attn_sub.softmax(dim=-1)

    output = attn_gather(attn_weighted, win_r)  # [B, N, h, M]

    return output


def attention_aggregate(v, attn, indices_gather, win_r):
    """
    Aggregate the sampled values with the effective attention weights
    Args:
        v: [B, N, h, C]
        attn: [B, N, h, M]
        indices_gather: [B, N, h, M, C] gather indices from compute_match_attention
    Returns:
        output: [B, N, h*C]
    """
    B, N, h, C = v.shape
    M = (2*win_r[0] + 2)*(2*win_r[1] + 2)

    # [B, N, h, C] -> [B, N, h, M, C]
    v_expanded = v.unsqueeze(3).expand(-1, -1, -1, M, -1)
    v_sampled = torch.gather(v_expanded, dim=1, index=indices_gather)

    output = (attn.unsqueeze(-1)*v_sampled).sum(dim=3)

    return output.view(B, N, -1)
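The four functions above compose into a pure-PyTorch reference of the fused CUDA path. A minimal sketch of one plausible wiring, with all sizes invented for illustration; the centers get a 0.5 fractional part so that floor(grid) matches the integer m_id.

import torch
from models.mat_pytorch_impl import (compute_bilinear_weights, compute_match_attention,
                                     compute_bilinear_softmax, attention_aggregate)

B, H, W, h, C = 1, 8, 8, 2, 4                   # tiny illustrative sizes
N = H * W
win_r = [1, 1]                                  # M = (2*1+2)**2 = 16 sampled positions
q, k, v = (torch.randn(B, N, h, C) for _ in range(3))
m_id = torch.randint(0, W, (B, N, h, 2))        # integer sampling centers (x, y)
grid = m_id.float() + 0.5                       # fractional centers; floor(grid) == m_id
w = compute_bilinear_weights(grid)              # [1, 64, 2, 4], each row sums to 1
attn, idx = compute_match_attention(q, k, m_id, win_r, H, W)   # [1, 64, 2, 16]
attn = compute_bilinear_softmax(attn, w, win_r)                # continuous-position attention
out = attention_aggregate(v, attn, idx, win_r)                 # [1, 64, 8]
print(out.shape)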
models/match_former_ops.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from typing import List, Tuple


@torch.library.custom_op("match_attention::fused_forward_ops", mutates_args={"output", "attn_out"})
def fused_forward_ops(
    max_offset: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    output: torch.Tensor,
    attn_out: torch.Tensor,
    H: int,
    W: int,
    win_r: List[int],
    attn_num: int,
    attn_type: str,
    scale: float
) -> None:
    """
    Opaque custom op for the fused forward pass that prevents torch.compile tracing.

    This wrapper ensures that torch.compile treats this as an opaque operation
    and doesn't try to trace into the CUDA kernel internals.
    """
    # Call the original CUDA extension
    try:
        import match_attention
        match_attention.fused_forward(
            max_offset, q, k, v, output, attn_out,
            H, W, win_r, attn_num, attn_type, scale
        )
    except ImportError:
        # Fall back to torch.ops if the direct import fails
        torch.ops.match_attention.fused_forward(
            max_offset, q, k, v, output, attn_out,
            H, W, win_r, attn_num, attn_type, scale
        )


@fused_forward_ops.register_fake
def _(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, scale):
    """
    Fake implementation for torch.compile that defines tensor shapes and dtypes
    without actually executing the kernel.
    """
    # Validate input shapes
    B, N, C = q.shape
    h = max_offset.size(2)

    # Ensure the output tensors have correct shapes
    torch._check(output.shape == (B, N, C), lambda: f"output shape mismatch: expected {(B, N, C)}, got {output.shape}")
    torch._check(attn_out.shape == (B, N, h, attn_num), lambda: f"attn_out shape mismatch: expected {(B, N, h, attn_num)}, got {attn_out.shape}")

    # Ensure the output tensors have correct dtypes and devices
    torch._check(output.dtype == q.dtype, lambda: f"output dtype mismatch: expected {q.dtype}, got {output.dtype}")
    torch._check(attn_out.dtype == q.dtype, lambda: f"attn_out dtype mismatch: expected {q.dtype}, got {attn_out.dtype}")
    torch._check(output.device == q.device, lambda: f"output device mismatch: expected {q.device}, got {output.device}")
    torch._check(attn_out.device == q.device, lambda: f"attn_out device mismatch: expected {q.device}, got {attn_out.device}")

    return None


class MF_FusedForwardOps(nn.Module):
    """
    Opaque MatchAttention fused forward, optimized for torch.compile

    This version uses torch.library.custom_op to create opaque custom operators,
    preventing torch.compile from tracing into CUDA kernel internals.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        max_offset: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        H: int,
        W: int,
        win_r: List[int],
        attn_num: int,
        attn_type: str = 'l1_norm',
        scale: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fused forward

        Args:
            max_offset: Sampling-center tensor with shape [B, N, h, 2]
            q: Query tensor with shape [B, N, C]
            k: Key tensor with shape [B, N, C]
            v: Value tensor with shape [B, N, C]
            H: Feature map height
            W: Feature map width
            win_r: Window radius [r_h, r_w]
            attn_num: Number of sampled attention positions per head
            attn_type: Attention type ('l1_norm' or 'dot_product')
            scale: Scale factor

        Returns:
            output: Output features with shape [B, N, C]
            attn_out: Attention weights with shape [B, N, h, attn_num]
        """
        B, N, C = q.shape
        h = max_offset.size(2)

        # Create output tensors
        output = torch.zeros_like(v)
        attn_out = q.new_zeros([B, N, h, attn_num])

        # Call the opaque custom operator
        fused_forward_ops(
            max_offset, q, k, v, output, attn_out,
            H, W, win_r, attn_num, attn_type, scale
        )

        return output, attn_out
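A usage sketch for the wrapper above, assuming the match_attention extension has been built and a CUDA device is available; the sizes are invented, and max_offset is zero-initialized (the kernels floor it to absolute (x, y) sampling centers and clamp it into range).

import torch
from models.match_former_ops import MF_FusedForwardOps

op = MF_FusedForwardOps()
B, H, W, heads, C = 1, 32, 64, 4, 64
N = H * W
q = torch.randn(B, N, C, device='cuda')
k = torch.randn(B, N, C, device='cuda')
v = torch.randn(B, N, C, device='cuda')
max_offset = torch.zeros(B, N, heads, 2, device='cuda')
win_r = [1, 1]
attn_num = (2*win_r[0] + 2) * (2*win_r[1] + 2)   # 16, matching the CUDA-side assert
out, attn = op(max_offset, q, k, v, H, W, win_r, attn_num, attn_type='l1_norm', scale=1.0)
print(out.shape, attn.shape)                     # [1, 2048, 64] and [1, 2048, 4, 16]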
models/match_stereo.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import trunc_normal_
from models.common import UpConv
from models.convformer import convformer
from models.attention_blocks import MatchAttentionBlock
from models.cost_volume import GlobalCorrelation


class MatchStereo(nn.Module):
    def __init__(self, args,
                 refine_win_rs=[2, 2, 1, 1],  # refine window radius at 1/32, 1/16, 1/8, 1/4
                 refine_nums=[8, 8, 8, 2],
                 num_heads=[4, 4, 4, 4],
                 mlp_ratios=[2, 2, 2, 2],
                 drop_path=0.):
        super().__init__()
        self.refine_nums = refine_nums

        self.encoder = convformer(args.variant)
        self.channels = self.encoder.dims[::-1]  # resolution low to high
        self.num_heads = num_heads
        self.head_dims = [c//h for c, h in zip(self.channels, self.num_heads)]

        self.factor = 2
        self.factor_last = 2**(len(self.channels) - len(refine_nums) + 2)

        self.field_dim = 2  # 2 (flow)

        self.up_decoders = nn.ModuleList()
        self.up_masks = nn.ModuleList()
        for i in range(len(self.channels)):
            if i > 0:
                self.up_decoders.append(UpConv(self.channels[i-1], self.channels[i]))
                self.up_masks.append(
                    nn.Sequential(
                        nn.Conv2d(self.channels[i-1], self.channels[i-1], 3, padding=1),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(self.channels[i-1], (self.factor**2)*9, 1, padding=0))
                )
            else:
                self.up_decoders.append(nn.Identity())
                self.up_masks.append(nn.Identity())

        self.up_masks.append(
            nn.Sequential(
                nn.Conv2d(self.channels[-1], self.channels[-1]*2, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.channels[-1]*2, (self.factor_last**2)*9, 1, padding=0)))

        dp_rates = [x.item() for x in torch.linspace(0, drop_path, sum(refine_nums))]
        # MatchAttention
        self.match_attentions = nn.ModuleList()
        for i in range(len(refine_nums)):
            self.match_attentions.append(
                MatchAttentionBlock(args, self.channels[i], win_r=refine_win_rs[i],
                                    num_layer=refine_nums[i], num_head=self.num_heads[i], head_dim=self.head_dims[i],
                                    mlp_ratio=mlp_ratios[i], field_dim=self.field_dim,
                                    dp_rates=dp_rates[sum(refine_nums[:i]):sum(refine_nums[:i+1])])
            )

        self.init_correlation_volume = GlobalCorrelation(self.channels[0])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def upsample_field(self, field, mask, factor):
        ''' Upsample field [H/factor, W/factor, D] -> [H, W, D] using convex combination '''
        B, H, W, D = field.shape
        field = field.permute(0, 3, 1, 2)
        mask = mask.view(B, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2).to(mask.dtype)
        up_flow = F.unfold(field*factor, [3, 3], padding=1)
        up_flow = up_flow.view(B, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2).to(mask.dtype)  # [B, D, factor, factor, H, W]
        up_flow = up_flow.permute(0, 4, 2, 5, 3, 1)
        return up_flow.reshape(B, factor*H, factor*W, D).contiguous()

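    # The mask predicts, for every fine-resolution pixel, softmax weights over its
    # 3x3 coarse neighborhood (RAFT-style convex upsampling); multiplying the field
    # by `factor` rescales the flow magnitudes to the finer grid.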
    def forward(self, img0, img1, stereo=True, init_flow=None):
        ''' Estimate optical flow/disparity between a pair of frames; outputs bi-directional flow/disparity '''
        field_all = []

        img0 = (2 * (img0 / 255.0) - 1.0).contiguous()
        img1 = (2 * (img1 / 255.0) - 1.0).contiguous()

        x = torch.cat((img0, img1), dim=0)  # cat in batch dim

        features = self.encoder(x)  # [B*2, H, W, C]
        features = features[::-1]  # reverse to 1/32, 1/16, 1/8, 1/4 order

        for i in range(len(features)):  # 1/32, 1/16, 1/8, 1/4
            if i == 0:
                if init_flow is None:
                    init_flow, init_cv = self.init_correlation_volume(features[i], stereo=stereo)
                else:
                    init_cv = None

                field = init_flow.clone()  # [B, H, W, 2]
                self_rpos = torch.zeros_like(field)
            else:
                features[i] = self.up_decoders[i](features[i-1], features[i])
                up_mask = self.up_masks[i](features[i-1].permute(0, 3, 1, 2))  # [B, C, H, W]
                self_rpos = self.upsample_field(self_rpos, up_mask, self.factor)
                field = self.upsample_field(field, up_mask, self.factor)
                field_all.append({'self': field})

            features[i], self_rpos, field, fields = self.match_attentions[i](features[i], self_rpos, field, stereo=stereo)
            field_all.extend(fields)

        if self.training:
            B = field.shape[0]
            field_up = self.upsample_field(field[:B//2], self.up_masks[-1](features[-1][:B//2].permute(0, 3, 1, 2)), self.factor_last)
            field_up = torch.cat((field_up, field_up), dim=0)  # dummy output
        else:
            field_up = self.upsample_field(field, self.up_masks[-1](features[-1].permute(0, 3, 1, 2)), self.factor_last)

        return {
            'init_flow': init_flow,
            'init_cv': init_cv,
            'field_all': field_all,
            'field_up': field_up,
            'self_rpos': self_rpos,
        }
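An end-to-end sketch, hedged: `variant` is the only args field this file reads, but MatchAttentionBlock (models/attention_blocks.py) may require further fields, and no pretrained weights are loaded here. Inputs are expected in 0..255 since forward normalizes by /255.

import torch
from types import SimpleNamespace
from models.match_stereo import MatchStereo

args = SimpleNamespace(variant='tiny')
model = MatchStereo(args).cuda().eval()

left = torch.randint(0, 256, (1, 3, 256, 512), device='cuda').float()
right = torch.randint(0, 256, (1, 3, 256, 512), device='cuda').float()
with torch.no_grad():
    out = model(left, right, stereo=True)
disp = out['field_up'][0, ..., 0].abs()   # left-view disparity; the x-flow is negative per disp_to_flow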
models/setup.py
ADDED
@@ -0,0 +1,20 @@
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

setup(
    name='match_attention',
    version='0.7',
    description='Match Attention CUDA Extension for PyTorch',
    author='TingmanYan',
    ext_modules=[
        CUDAExtension('match_attention', [
            'src/match_former_cuda.cpp',
            'src/match_former_cuda_kernel.cu',
            'src/match_former_fused_forward.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
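The extension builds with the standard torch cpp_extension flow, e.g. `python setup.py install` run from models/ (the repo also ships models/compile.sh for this step; its exact flags are not shown here). A quick import check then confirms both registrations defined in src/match_former_cuda.cpp below:

import torch
import match_attention                                        # pybind11 module

print(hasattr(match_attention, 'fused_forward'))              # True: PYBIND11_MODULE binding
print(hasattr(torch.ops.match_attention, 'fused_forward'))    # True: TORCH_LIBRARY registration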
models/src/match_former_cuda.cpp
ADDED
@@ -0,0 +1,49 @@
#include <torch/extension.h>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <ATen/core/op_registration/op_registration.h>

// CUDA declarations

void mf_fused_forward_cuda(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int H,
    const int W,
    const std::vector<int64_t>& win_r,
    const int attn_num,
    const std::string& attn_type,
    const float scale);

void mf_fused_forward(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int64_t H,
    const int64_t W,
    const std::vector<int64_t>& win_r,
    const int64_t attn_num,
    const std::string& attn_type,
    const double scale)
{
    mf_fused_forward_cuda(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, static_cast<float>(scale));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("fused_forward", &mf_fused_forward, "Fused forward pass (CUDA)");
}

TORCH_LIBRARY(match_attention, m)
{
    m.def("fused_forward(Tensor max_offset, Tensor q, Tensor k, Tensor v, Tensor(a!) output, Tensor(b!) attn_out, int H, int W, int[] win_r, int attn_num, str attn_type, float scale) -> ()", &mf_fused_forward);
}
models/src/match_former_cuda_kernel.cu
ADDED
@@ -0,0 +1,26 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#include "match_former_fused_forward.hpp"

// Fused forward function that combines all operations
void mf_fused_forward_cuda(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int H,
    const int W,
    const std::vector<int64_t>& win_r,
    const int attn_num,
    const std::string& attn_type,
    const float scale)
{
    match_former_fused_forward(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, scale);
}
models/src/match_former_fused_forward.cu
ADDED
@@ -0,0 +1,628 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cassert>
#include <cfloat>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <ATen/native/cuda/KernelUtils.cuh>

// Forward declarations of kernel functions
template <typename scalar_t>
__global__ void clip_offset_to_id_k(const scalar_t *const m_offset_d, int *const m_id_d, const int Lh, const int num_heads, const int N, const int H, const int W);

template <typename scalar_t>
__global__ void attn_weight_bilinear_forward_k(const scalar_t* const m_offset_d, scalar_t* const bilinear_weight_d, const int Lh);

__global__ void check_max_id_k(int *const m_id_d, const int L, const int N, const int H, const int W, const int num_heads, const int win_x, const int win_y);

template <typename scalar_t>
__global__ void match_attention_l1_norm_forward_k(
    const scalar_t *__restrict__ q_d,
    const scalar_t *__restrict__ k_d,
    scalar_t *__restrict__ attn_d,
    const int *__restrict__ m_id_d,
    const int *__restrict__ offset_d,
    const int L, const int N, const int H, const int W,
    const int C, const int num_heads, const int key_dim,
    const int attn_num, const int attn_numel,
    const bool swap_xy);

template <typename scalar_t>
__global__ void match_attention_dot_product_forward_k(const scalar_t *const q_d, const scalar_t *const k_d, scalar_t *const attn_d, const int *const m_id_d, const int* const offset_d, const int L, const int N, const int H, const int W, const int C, const int num_heads, const int key_dim, const int attn_num, const int attn_numel, const bool swap_xy);

template <typename scalar_t>
__global__ void bilinear_softmax_forward_general_k(scalar_t* const __restrict__ attn_d,
                                                   scalar_t* const __restrict__ attn_out_d,
                                                   scalar_t* const __restrict__ attn_sum_d,
                                                   const scalar_t* const __restrict__ bilinear_weight_d,
                                                   const int* const __restrict__ select_index_d,
                                                   int L, const int num_heads, const int h_attn_num,
                                                   const int attn_num, const int attn_num_sub);

template <typename scalar_t>
__global__ void attention_aggregate_forward_k(
    const scalar_t *__restrict__ v_d,
    scalar_t *__restrict__ out_d,
    const scalar_t *__restrict__ attn_d,
    const int *__restrict__ m_id_d,
    const int* __restrict__ offset_d,
    const int L, const int C, const int num_heads,
    const int key_dim, const int attn_num,
    const bool swap_xy);

template <typename scalar_t>
__global__ void scale_attention_k(scalar_t* attn_d, const scalar_t scale, const int total_size);

// Kernel implementations
template <typename scalar_t>
__global__ void
clip_offset_to_id_k(const scalar_t *const m_offset_d, int *const m_id_d, const int Lh, const int num_heads, const int N, const int H, const int W)
{
    int lh = blockIdx.x * blockDim.x + threadIdx.x;
    if (lh >= Lh)
        return;

    int l = lh / num_heads;
    int batch_id = l / N;
    int m_x = __float2int_rd(static_cast<float>(m_offset_d[lh*2]));  // round to floor
    int m_y = __float2int_rd(static_cast<float>(m_offset_d[lh*2 + 1]));
    if (m_x < 0) m_x = 0;
    if (m_x >= W) m_x = W - 1;
    if (m_y < 0) m_y = 0;
    if (m_y >= H) m_y = H - 1;
    int m_pix_id = m_y * W + m_x;
    int m_id = batch_id * N + m_pix_id;
    m_id_d[lh] = m_id;
}

template <typename scalar_t>
__global__ void
attn_weight_bilinear_forward_k(const scalar_t* const m_offset_d, scalar_t* const bilinear_weight_d, const int Lh)
{
    int lh = blockIdx.x * blockDim.x + threadIdx.x;
    if (lh >= Lh)
        return;

    float ix = static_cast<float>(m_offset_d[lh*2]);
    float iy = static_cast<float>(m_offset_d[lh*2 + 1]);
    int ix_nw = __float2int_rd(ix);
    int iy_nw = __float2int_rd(iy);
    int ix_ne = ix_nw + 1;
    int iy_ne = iy_nw;
    int ix_sw = ix_nw;
    int iy_sw = iy_nw + 1;
    int ix_se = ix_nw + 1;
    int iy_se = iy_nw + 1;

    float nw = (ix_se - ix) * (iy_se - iy);
    float ne = (ix - ix_sw) * (iy_sw - iy);
    float sw = (ix_ne - ix) * (iy - iy_ne);
    float se = (ix - ix_nw) * (iy - iy_nw);
    bilinear_weight_d[lh*4]     = static_cast<scalar_t>(nw);
    bilinear_weight_d[lh*4 + 1] = static_cast<scalar_t>(ne);
    bilinear_weight_d[lh*4 + 2] = static_cast<scalar_t>(sw);
    bilinear_weight_d[lh*4 + 3] = static_cast<scalar_t>(se);  // bilinear_weight of shape [B, N, h, 4]
}

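// Note: nw + ne + sw + se == 1 by construction (the four terms tile the unit
// cell around (ix, iy)); the values agree with the PyTorch reference
// compute_bilinear_weights() in models/mat_pytorch_impl.py.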
// check if the search window range is out of image coordinates
__forceinline__ __device__ void
check_within_image_coordinates(int& l_id, const int& N, const int& H, const int& W, const int& win_x, const int& win_y)
{
    int pix_id = l_id % N;
    int batch_id = l_id / N;
    int x = pix_id % W;
    int y = pix_id / W;
    if (x - win_x < 0)
        x = win_x;
    if (x + (win_x + 1) >= W)
        x = W - 1 - (win_x + 1);
    if (y - win_y < 0)
        y = win_y;
    if (y + (win_y + 1) >= H)
        y = H - 1 - (win_y + 1);
    pix_id = y * W + x;
    l_id = batch_id * N + pix_id;
}

__global__ void
check_max_id_k(int *const m_id_d, const int L, const int N, const int H, const int W, const int num_heads, const int win_x, const int win_y)
{
    int l, h;
    l = blockIdx.x * blockDim.x + threadIdx.x;
    h = blockIdx.y * blockDim.y + threadIdx.y;
    if (l >= L || h >= num_heads)
        return;

    int m_id = m_id_d[l * num_heads + h];
    check_within_image_coordinates(m_id, N, H, W, win_x, win_y);
    m_id_d[l * num_heads + h] = m_id;
}

template <typename scalar_t>
__global__ void match_attention_l1_norm_forward_k(
    const scalar_t *__restrict__ q_d,
    const scalar_t *__restrict__ k_d,
    scalar_t *__restrict__ attn_d,
    const int *__restrict__ m_id_d,
    const int *__restrict__ offset_d,
    const int L, const int N, const int H, const int W,
    const int C, const int num_heads, const int key_dim,
    const int attn_num, const int attn_numel,
    const bool swap_xy)
{
    int l, k;
    if (swap_xy)
    {
        l = blockIdx.x * blockDim.x + threadIdx.x;
        k = blockIdx.y * blockDim.y + threadIdx.y;
    }
    else
    {
        k = blockIdx.x * blockDim.x + threadIdx.x;
        l = blockIdx.y * blockDim.y + threadIdx.y;
    }
    if (l >= L || k >= num_heads*attn_num)
        return;

    constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
    const int h = k / attn_num;
    const int attn_id = k % attn_num;
    const int base_id = l*num_heads + h;
    const int base_attn_id = base_id*attn_num;
    const int key_id = m_id_d[base_id] + offset_d[attn_id];

    const int q_base = l * C;
    const int k_base = key_id * C;
    const int c_start = h * key_dim / vec_size;
    const int c_end = c_start + key_dim / vec_size;

    const float4* q_val_vec = reinterpret_cast<const float4*>(q_d + q_base);
    const float4* k_val_vec = reinterpret_cast<const float4*>(k_d + k_base);

    float diff_sum = 0.0f;

    for (int c = c_start; c < c_end; ++c) {
        float4 q_val_f4 = __ldg(&q_val_vec[c]);
        float4 k_val_f4 = __ldg(&k_val_vec[c]);

        if (vec_size == 4) {  // float32
            diff_sum += fabsf(q_val_f4.x - k_val_f4.x) +
                        fabsf(q_val_f4.y - k_val_f4.y) +
                        fabsf(q_val_f4.z - k_val_f4.z) +
                        fabsf(q_val_f4.w - k_val_f4.w);
        } else {  // bf16/fp16 (8 elements)
            if (std::is_same<scalar_t, at::Half>::value) {
                const half2* q_val_h2 = reinterpret_cast<const half2*>(&q_val_f4);
                const half2* k_val_h2 = reinterpret_cast<const half2*>(&k_val_f4);
#pragma unroll
                for (int i = 0; i < 4; ++i) {
                    half2 q_h2 = q_val_h2[i];
                    half2 k_h2 = k_val_h2[i];
                    half2 diff_h2 = __habs2(__hsub2(q_h2, k_h2));
                    diff_sum += __half2float(diff_h2.x) + __half2float(diff_h2.y);
                }
            } else {  // bf16
                const __nv_bfloat162* q_val_bf2 = reinterpret_cast<const __nv_bfloat162*>(&q_val_f4);
                const __nv_bfloat162* k_val_bf2 = reinterpret_cast<const __nv_bfloat162*>(&k_val_f4);
#pragma unroll
                for (int i = 0; i < 4; ++i) {
                    __nv_bfloat162 q_bf2 = q_val_bf2[i];
                    __nv_bfloat162 k_bf2 = k_val_bf2[i];
                    __nv_bfloat162 diff_bf2 = __habs2(__hsub2(q_bf2, k_bf2));
                    diff_sum += __bfloat162float(diff_bf2.x) + __bfloat162float(diff_bf2.y);
                }
            }
        }
    }
    attn_d[base_attn_id + attn_id] = static_cast<scalar_t>(-diff_sum);
}

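// Note on the vectorized kernel above: q/k are read through 16-byte float4
// loads, so key_dim must be a multiple of vec_size (4 for fp32, 8 for
// fp16/bf16) and the per-head slices must stay 16-byte aligned; the head
// dims produced by the MatchStereo configs are multiples of 8 and satisfy this.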
| 221 |
+
|
| 222 |
+
template <typename scalar_t>
|
| 223 |
+
__global__ void
|
| 224 |
+
match_attention_dot_product_forward_k(const scalar_t *const q_d, const scalar_t *const k_d, scalar_t *const attn_d, const int *const m_id_d, const int* const offset_d, const int L, const int N, const int H, const int W, const int C, const int num_heads, const int key_dim, const int attn_num, const int attn_numel, const bool swap_xy)
|
| 225 |
+
{
|
| 226 |
+
int l, k;
|
| 227 |
+
if (swap_xy)
|
| 228 |
+
{
|
| 229 |
+
l = blockIdx.x * blockDim.x + threadIdx.x;
|
| 230 |
+
k = blockIdx.y * blockDim.y + threadIdx.y;
|
| 231 |
+
}
|
| 232 |
+
else
|
| 233 |
+
{
|
| 234 |
+
k = blockIdx.x * blockDim.x + threadIdx.x;
|
| 235 |
+
l = blockIdx.y * blockDim.y + threadIdx.y;
|
| 236 |
+
}
|
| 237 |
+
if (l >= L || k >= num_heads*attn_num)
|
| 238 |
+
return;
|
| 239 |
+
|
| 240 |
+
int h = k / attn_num;
|
| 241 |
+
int attn_id = k % attn_num;
|
| 242 |
+
int base_id = l*num_heads + h;
|
| 243 |
+
int base_attn_id = base_id*attn_num;
|
| 244 |
+
int key_id = m_id_d[base_id] + offset_d[attn_id];
|
| 245 |
+
scalar_t diff_sum = 0;
|
| 246 |
+
for (int c = h * key_dim; c < (h + 1) * key_dim; ++c)
|
| 247 |
+
{
|
| 248 |
+
diff_sum += q_d[l * C + c] * k_d[key_id * C + c];
|
| 249 |
+
}
|
| 250 |
+
attn_d[base_attn_id + attn_id] = diff_sum;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
template <typename scalar_t>
|
| 254 |
+
__global__ void scale_attention_k(scalar_t* attn_d, const scalar_t scale, const int total_size)
|
| 255 |
+
{
|
| 256 |
+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 257 |
+
if (idx >= total_size)
|
| 258 |
+
return;
|
| 259 |
+
attn_d[idx] = attn_d[idx] * scale;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
template <typename T> struct VecType { using Type = T; };
|
| 263 |
+
template <> struct VecType<float> { using Type = float4; };
|
| 264 |
+
template <> struct VecType<__half> { using Type = float2; };
|
| 265 |
+
template <> struct VecType<__nv_bfloat16> { using Type = float2; };
|
| 266 |
+
|
| 267 |
+
template <typename scalar_t>
|
| 268 |
+
__device__ __inline__ typename VecType<scalar_t>::Type load_vec(const scalar_t* addr) {
|
| 269 |
+
return *reinterpret_cast<const typename VecType<scalar_t>::Type*>(addr);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
template <typename scalar_t>
|
| 273 |
+
__device__ __inline__ void store_vec(scalar_t* addr, typename VecType<scalar_t>::Type val) {
|
| 274 |
+
*reinterpret_cast<typename VecType<scalar_t>::Type*>(addr) = val;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
template <int WIN_SIZE, typename scalar_t>
|
| 278 |
+
__device__ __forceinline__ void load_window(scalar_t* window, const scalar_t* src) {
|
| 279 |
+
constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
|
| 280 |
+
constexpr int VEC_COUNT = WIN_SIZE / VEC_ELEMS;
|
| 281 |
+
using vec_t = typename VecType<scalar_t>::Type;
|
| 282 |
+
|
| 283 |
+
#pragma unroll 4
|
| 284 |
+
for (int i = 0; i < VEC_COUNT; ++i) {
|
| 285 |
+
vec_t vec = load_vec<scalar_t>(src + i * VEC_ELEMS);
|
| 286 |
+
store_vec<scalar_t>(window + i * VEC_ELEMS, vec);
|
| 287 |
+
}
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
template <int WIN_SIZE, typename scalar_t>
|
| 291 |
+
__device__ __forceinline__ void store_window(scalar_t* dst, const scalar_t* window) {
|
| 292 |
+
constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
|
| 293 |
+
constexpr int VEC_COUNT = WIN_SIZE / VEC_ELEMS;
|
| 294 |
+
using vec_t = typename VecType<scalar_t>::Type;
|
| 295 |
+
|
| 296 |
+
#pragma unroll 4
|
| 297 |
+
for (int i = 0; i < VEC_COUNT; ++i) {
|
| 298 |
+
vec_t vec = load_vec<scalar_t>(window + i * VEC_ELEMS);
|
| 299 |
+
store_vec<scalar_t>(dst + i * VEC_ELEMS, vec);
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
template <int WIN_SIZE, int SUB_WIN_SIZE, typename scalar_t>
|
| 304 |
+
__global__ void
|
| 305 |
+
bilinear_softmax_forward_k(scalar_t* const __restrict__ attn_d,
|
| 306 |
+
scalar_t* const __restrict__ attn_out_d,
|
| 307 |
+
scalar_t* const __restrict__ attn_sum_d,
|
| 308 |
+
const scalar_t* const __restrict__ bilinear_weight_d,
|
| 309 |
+
const int* const __restrict__ select_index_d,
|
| 310 |
+
int L, const int num_heads, const int h_attn_num,
|
| 311 |
+
const int attn_num)
|
| 312 |
+
{
|
| 313 |
+
constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
|
| 314 |
+
static_assert(WIN_SIZE % VEC_ELEMS == 0, "WIN_SIZE must be divisible by vector elements");
|
| 315 |
+
using acc_t = float;
|
| 316 |
+
|
| 317 |
+
int l = blockIdx.x * blockDim.x + threadIdx.x;
|
| 318 |
+
int h = blockIdx.y * blockDim.y + threadIdx.y;
|
| 319 |
+
if (l >= L || h >= num_heads)
|
| 320 |
+
return;
|
| 321 |
+
|
| 322 |
+
const int base_attn_id = l * h_attn_num + h * attn_num;
|
| 323 |
+
const int base_sum_idx = l * (num_heads * 4) + h * 4;
|
| 324 |
+
|
| 325 |
+
scalar_t window[WIN_SIZE];
|
| 326 |
+
load_window<WIN_SIZE>(window, attn_d + base_attn_id);
|
| 327 |
+
|
| 328 |
+
acc_t attn_max = -FLT_MAX;
|
| 329 |
+
#pragma unroll 4
|
| 330 |
+
for (int k = 0; k < WIN_SIZE; ++k) {
|
| 331 |
+
if (static_cast<acc_t>(window[k]) > attn_max) {
|
| 332 |
+
attn_max = static_cast<acc_t>(window[k]);
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
#pragma unroll 4
|
| 337 |
+
for (int k = 0; k < WIN_SIZE; ++k) {
|
| 338 |
+
window[k] = static_cast<scalar_t>(expf(static_cast<acc_t>(window[k]) - attn_max));
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
scalar_t window_out[WIN_SIZE] = {0};
|
| 342 |
+
|
| 343 |
+
for (int b = 0; b < 4; ++b) {
|
| 344 |
+
acc_t block_sum = 0.0f;
|
| 345 |
+
const int* block_idx = select_index_d + b * SUB_WIN_SIZE;
|
| 346 |
+
|
| 347 |
+
#pragma unroll 4
|
| 348 |
+
for (int k = 0; k < SUB_WIN_SIZE; ++k) {
|
| 349 |
+
block_sum += static_cast<acc_t>(window[block_idx[k]]);
|
| 350 |
+
}
|
| 351 |
+
block_sum = fmaxf(block_sum, FLT_EPSILON);
|
| 352 |
+
attn_sum_d[base_sum_idx + b] = static_cast<scalar_t>(block_sum);
|
| 353 |
+
|
| 354 |
+
const scalar_t weight = bilinear_weight_d[base_sum_idx + b];
|
| 355 |
+
const scalar_t scale = static_cast<scalar_t>(static_cast<acc_t>(weight) / block_sum);
|
| 356 |
+
|
| 357 |
+
#pragma unroll 4
|
| 358 |
+
for (int k = 0; k < SUB_WIN_SIZE; ++k) {
|
| 359 |
+
const int idx = block_idx[k];
|
| 360 |
+
window_out[idx] = window_out[idx] + window[idx] * scale;
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
// write back to global memory
|
| 365 |
+
store_window<WIN_SIZE>(attn_out_d + base_attn_id, window_out);
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
template <typename scalar_t>
|
| 369 |
+
__global__ void
|
| 370 |
+
bilinear_softmax_forward_general_k(scalar_t* const __restrict__ attn_d,
|
| 371 |
+
scalar_t* const __restrict__ attn_out_d,
|
| 372 |
+
scalar_t* const __restrict__ attn_sum_d,
|
| 373 |
+
const scalar_t* const __restrict__ bilinear_weight_d,
|
| 374 |
+
const int* const __restrict__ select_index_d,
|
| 375 |
+
int L, const int num_heads, const int h_attn_num,
|
| 376 |
+
const int attn_num, const int attn_num_sub)
|
| 377 |
+
{
|
| 378 |
+
int l, h;
|
| 379 |
+
l = blockIdx.x * blockDim.x + threadIdx.x;
|
| 380 |
+
h = blockIdx.y * blockDim.y + threadIdx.y;
|
| 381 |
+
if (l >= L || h >= num_heads)
|
| 382 |
+
return;
|
| 383 |
+
|
| 384 |
+
scalar_t attn_max = -FLT_MAX;
|
| 385 |
+
int base_attn_id = l * h_attn_num + h * attn_num;
|
| 386 |
+
for (int k = 0; k < attn_num; ++k)
|
| 387 |
+
{
|
| 388 |
+
scalar_t attn_val = attn_d[base_attn_id + k];
|
| 389 |
+
if (attn_val > attn_max) {
|
| 390 |
+
attn_max = attn_val;
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
__syncthreads();
|
| 394 |
+
|
| 395 |
+
for (int k = 0; k < attn_num; ++k)
|
| 396 |
+
{
|
| 397 |
+
attn_d[base_attn_id + k] = expf(attn_d[base_attn_id + k] - attn_max);
|
| 398 |
+
}
|
| 399 |
+
__syncthreads();
|
| 400 |
+
|
| 401 |
+
for (int b = 0; b < 4; ++b)
|
| 402 |
+
{
|
| 403 |
+
scalar_t attn_sum = 0;
|
| 404 |
+
for (int k = 0; k < attn_num_sub; ++k)
|
| 405 |
+
{
|
| 406 |
+
attn_sum += attn_d[base_attn_id + select_index_d[b*attn_num_sub + k]];
|
| 407 |
+
}
|
| 408 |
+
attn_sum = fmaxf(attn_sum, FLT_EPSILON);
|
| 409 |
+
attn_sum_d[l*(num_heads*4) + h*4 + b] = attn_sum; // save for backward
|
| 410 |
+
|
| 411 |
+
scalar_t weight = bilinear_weight_d[l*num_heads*4 + h*4 + b];
|
| 412 |
+
for (int k = 0; k < attn_num_sub; ++k)
|
| 413 |
+
{
|
| 414 |
+
int select_index = select_index_d[b*attn_num_sub + k];
|
| 415 |
+
attn_out_d[base_attn_id + select_index] +=
|
| 416 |
+
attn_d[base_attn_id + select_index] / attn_sum * weight; // no write conflict
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
}
|
| 420 |
+
|
template <typename scalar_t>
__global__ void attention_aggregate_forward_k(
    const scalar_t *__restrict__ v_d,
    scalar_t *__restrict__ out_d,
    const scalar_t *__restrict__ attn_d,
    const int *__restrict__ m_id_d,
    const int* __restrict__ offset_d,
    const int L, const int C, const int num_heads,
    const int key_dim, const int attn_num,
    const bool swap_xy)
{
    int c, l;
    if (swap_xy)
    {
        l = blockIdx.x * blockDim.x + threadIdx.x;
        c = blockIdx.y * blockDim.y + threadIdx.y;
    }
    else
    {
        c = blockIdx.x * blockDim.x + threadIdx.x;
        l = blockIdx.y * blockDim.y + threadIdx.y;
    }
    if (l >= L || c >= C)
        return;

    const int h = c / key_dim;
    const int base_id = l*num_heads + h;
    const int base_attn_id = base_id*attn_num;
    const int m_id = m_id_d[base_id];
    float out_sum = 0;
    for (int k = 0; k < attn_num; ++k)
    {
        int key_id = m_id + offset_d[k];
        out_sum += static_cast<float>(attn_d[base_attn_id + k]) *
                   static_cast<float>(v_d[key_id * C + c]);
    }
    out_d[l * C + c] = static_cast<scalar_t>(out_sum);
}

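// Host-side entry point. Shape conventions, as inferred from the indexing
// below: q, k, v are [B, N, C] with N = H * W flattened tokens, max_offset
// carries one 2D offset per head (indexed as [B, N, h, 2]), and attn_out
// is [B, N, h, attn_num].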
// Main fused forward function
void match_former_fused_forward(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int H,
    const int W,
    const std::vector<int64_t>& win_r,
    const int attn_num,
    const std::string& attn_type,
    const float scale)
{
    const int B = q.size(0);
    const int N = q.size(1);
    const int C = q.size(2);
    const int h = max_offset.size(2);
    const int key_dim = C / h;
    const int L = B * N;
    const int Lh = L * h;
    const int attn_numel = L * h * attn_num;
    const int win_x = win_r[0];
    const int win_y = win_r[1];
    assert(attn_num == (2*win_r[0]+2)*(2*win_r[1]+2));
    const bool swap_xy_match = (h * attn_num < 32);
    const bool swap_xy_agg = (C < 32);
    const int attn_num_sub = (2*win_r[0] + 1)*(2*win_r[1] + 1);
    const int h_attn_num = h * attn_num;

    // Create temporary tensors
    auto m_id = at::zeros({B, N, h}, at::TensorOptions().dtype(at::kInt).device(max_offset.device()));
    auto bilinear_weight = at::zeros({B, N, h, 4}, max_offset.options());
    auto attn = at::zeros({B, N, h, attn_num}, q.options());
    auto attn_sum = at::zeros({B, N, h, 4}, q.options());

    // Create offset array for window
    int *offset_d;
    cudaMalloc(&offset_d, sizeof(int) * attn_num);
    int *offset_h = new int[attn_num];
    int num = 0;
    for (int y = -win_y; y <= (win_y + 1); ++y)
        for (int x = -win_x; x <= (win_x + 1); ++x)
        {
            offset_h[num++] = y * W + x;
        }
    cudaMemcpy(offset_d, offset_h, sizeof(int) * attn_num, cudaMemcpyHostToDevice);
    delete[] offset_h;

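    // The four index lists built below each select the (2*win_r+1)^2
    // sub-window anchored at one bilinear corner (dx, dy) in {0, 1}^2 of
    // the padded (2*win_r+2)^2 window; the softmax kernels normalize each
    // sub-window separately before blending with the corner weights.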
    // Create select_index array for bilinear softmax
    int *select_index_d;
    cudaMalloc(&select_index_d, sizeof(int)*4*attn_num_sub);
    int *select_index_h = new int[4*attn_num_sub];
    int win_W = 2*(win_r[0]+1);
    int delta_x[4] = {0, 1, 0, 1};
    int delta_y[4] = {0, 0, 1, 1};
    num = 0;
    for (int b = 0; b < 4; ++b) {
        int d_x = delta_x[b];
        int d_y = delta_y[b];
        for (int y = d_y; y <= 2*win_r[1] + d_y; ++y)
            for (int x = d_x; x <= 2*win_r[0] + d_x; ++x)
            {
                select_index_h[num++] = y * win_W + x;
            }
    }
    cudaMemcpy(select_index_d, select_index_h, sizeof(int)*attn_num_sub*4, cudaMemcpyHostToDevice);
    delete[] select_index_h;

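    // The fused forward then runs as seven small launches: (1) clip the
    // predicted offsets to integer ids, (2) derive the four bilinear corner
    // weights from the fractional parts, (3) bound-check the ids against
    // the window extent, (4) compute q-k attention over the window (dot
    // product or L1), (5) scale the logits, (6) apply the bilinear softmax,
    // and (7) aggregate the values with the blended attention.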
    // Step 1: Clip offset to id
    {
        int grid = (Lh + 512 - 1) / 512;
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, max_offset.scalar_type(), "clip_offset_to_id_k", ([&] {
            clip_offset_to_id_k<scalar_t><<<grid, 512>>>(max_offset.data_ptr<scalar_t>(), m_id.data_ptr<int>(), Lh, h, N, H, W);
        }));
    }

    // Step 2: Compute bilinear weights
    {
        int grid = (Lh + 512 - 1) / 512;
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, max_offset.scalar_type(), "attn_weight_bilinear_forward_k", ([&] {
            attn_weight_bilinear_forward_k<scalar_t><<<grid, 512>>>(max_offset.data_ptr<scalar_t>(), bilinear_weight.data_ptr<scalar_t>(), Lh);
        }));
    }

    // Step 3: Check max id bounds
    {
        dim3 m_blocks(8, 128);
        dim3 grids((L + m_blocks.x - 1) / m_blocks.x, (h + m_blocks.y - 1) / m_blocks.y);
        check_max_id_k<<<grids, m_blocks>>>(m_id.data_ptr<int>(), L, N, H, W, h, win_x, win_y);
    }

    // Step 4: Compute attention
    {
        dim3 m_blocks(8, 128);
        dim3 grids((h*attn_num + m_blocks.x - 1) / m_blocks.x, (L + m_blocks.y - 1) / m_blocks.y);
        if (swap_xy_match)
            grids = dim3((L + m_blocks.x - 1) / m_blocks.x, (h*attn_num + m_blocks.y - 1) / m_blocks.y);

        if (attn_type == "dot_product") {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(q.scalar_type(), "match_attention_dot_product_forward_k", ([&] {
                match_attention_dot_product_forward_k<scalar_t><<<grids, m_blocks>>>(q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), attn.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, N, H, W, C, h, key_dim, attn_num, attn_numel, swap_xy_match);
            }));
        } else if (attn_type == "l1_norm") {
            AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "match_attention_l1_norm_forward_k", ([&] {
                match_attention_l1_norm_forward_k<scalar_t><<<grids, m_blocks>>>(q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), attn.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, N, H, W, C, h, key_dim, attn_num, attn_numel, swap_xy_match);
            }));
        }
    }

    // Step 5: Scale attention
    {
        int grid = (attn_numel + 512 - 1) / 512;
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "scale_attention_k", ([&] {
            scale_attention_k<scalar_t><<<grid, 512>>>(attn.data_ptr<scalar_t>(), static_cast<scalar_t>(scale), attn_numel);
        }));
    }

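    // attn_num == 16 with attn_num_sub == 9 corresponds to win_r = (1, 1),
    // and attn_num == 36 with attn_num_sub == 25 to win_r = (2, 2); these
    // common cases take the specialized kernel that keeps the whole window
    // in thread-local storage, everything else falls back to the general
    // kernel above.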
    // Step 6: Bilinear softmax
    {
        dim3 m_blocks = (attn_num == 16) ? dim3(128, 4) : dim3(32, 4);
        dim3 grids((L + m_blocks.x - 1) / m_blocks.x, (h + m_blocks.y - 1) / m_blocks.y);

        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attn.scalar_type(), "bilinear_softmax_forward", [&] {
            if (attn_num == 16 && attn_num_sub == 9) {
                bilinear_softmax_forward_k<16, 9><<<grids, m_blocks>>>(
                    attn.data_ptr<scalar_t>(),
                    attn_out.data_ptr<scalar_t>(),
                    attn_sum.data_ptr<scalar_t>(),
                    bilinear_weight.data_ptr<scalar_t>(),
                    select_index_d, L, h, h_attn_num, attn_num
                );
            } else if (attn_num == 36 && attn_num_sub == 25) {
                bilinear_softmax_forward_k<36, 25><<<grids, m_blocks>>>(
                    attn.data_ptr<scalar_t>(),
                    attn_out.data_ptr<scalar_t>(),
                    attn_sum.data_ptr<scalar_t>(),
                    bilinear_weight.data_ptr<scalar_t>(),
                    select_index_d, L, h, h_attn_num, attn_num
                );
            } else {
                bilinear_softmax_forward_general_k<<<grids, m_blocks>>>(
                    attn.data_ptr<scalar_t>(),
                    attn_out.data_ptr<scalar_t>(),
                    attn_sum.data_ptr<scalar_t>(),
                    bilinear_weight.data_ptr<scalar_t>(),
                    select_index_d, L, h, h_attn_num, attn_num, attn_num_sub
                );
            }
        });
    }

    // Step 7: Attention aggregation
    {
        dim3 m_blocks = (attn_num == 16) ? dim3(8, 128) : dim3(8, 32);
        dim3 grids((C + m_blocks.x - 1) / m_blocks.x, (L + m_blocks.y - 1) / m_blocks.y);
        if (swap_xy_agg)
            grids = dim3((L + m_blocks.x - 1) / m_blocks.x, (C + m_blocks.y - 1) / m_blocks.y);

        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, v.scalar_type(), "attention_aggregate_forward_k", ([&] {
            attention_aggregate_forward_k<scalar_t><<<grids, m_blocks>>>(v.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), attn_out.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, C, h, key_dim, attn_num, swap_xy_agg);
        }));
    }

    // Cleanup
    cudaFree(offset_d);
    cudaFree(select_index_d);
}
models/src/match_former_fused_forward.hpp
ADDED
@@ -0,0 +1,23 @@
#ifndef _MATCH_FORMER_FUSED_FORWARD_HPP_
#define _MATCH_FORMER_FUSED_FORWARD_HPP_

#include <vector>
#include <string>
#include <ATen/ATen.h>  // declares at::Tensor; keeps this header self-contained

// Fused forward function that combines all match former operations
void match_former_fused_forward(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int H,
    const int W,
    const std::vector<int64_t>& win_r,
    const int attn_num,
    const std::string& attn_type,
    const float scale);

#endif
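For orientation, a minimal sketch of how the fused op might be driven from Python once the extension is compiled via models/setup.py. The module and function names below are placeholders (the real binding lives in models/src/match_former_cuda.cpp and is wrapped by models/match_former_ops.py), so treat this as illustrative only:

    import torch
    import match_former_cuda  # hypothetical name of the compiled extension

    B, heads, H, W, C = 1, 4, 64, 128, 128
    N = H * W
    win_r = [1, 1]
    attn_num = (2 * win_r[0] + 2) * (2 * win_r[1] + 2)  # 16 for win_r = (1, 1)

    q = torch.randn(B, N, C, device='cuda')
    k = torch.randn(B, N, C, device='cuda')
    v = torch.randn(B, N, C, device='cuda')
    max_offset = torch.zeros(B, N, heads, 2, device='cuda')  # assumed layout
    output = torch.zeros_like(q)
    attn_out = torch.zeros(B, N, heads, attn_num, device='cuda')

    scale = (C // heads) ** -0.5  # standard attention scaling, an assumption here
    match_former_cuda.fused_forward(max_offset, q, k, v, output, attn_out,
                                    H, W, win_r, attn_num, 'dot_product', scale)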
requirements.txt
ADDED
@@ -0,0 +1,14 @@
imageio==2.9.0
imageio-ffmpeg==0.4.9
matplotlib==3.8.4
opencv-python==4.9.0.80
pillow==10.2.0
scikit-image==0.20.0
scipy==1.9.1
tensorboard==2.17.0
setuptools==59.5.0
psutil==6.0.0
joblib==1.4.2
numpy==1.24.4
tqdm==4.66.2
timm==0.6.11
utils/file_io.py
ADDED
@@ -0,0 +1,37 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import sys

def write_pfm(file, image, scale=1):
    file = open(file, 'wb')

    color = None

    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(
            image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception(
            'Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    file.write(b'PF\n' if color else b'Pf\n')
    file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))

    endian = image.dtype.byteorder

    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    file.write(b'%f\n' % scale)

    image.tofile(file)
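A small usage sketch (not part of the release): saving a float32 disparity map produced by the model as a PFM file.

    import numpy as np
    from utils.file_io import write_pfm

    disp = np.random.rand(480, 640).astype(np.float32)  # placeholder disparity
    write_pfm('disp.pfm', disp)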
utils/utils.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn.functional as F
import numpy as np


class InputPadder:
    """ Pads images such that dimensions are divisible by padding_factor """

    def __init__(self, dims, mode='top_right', padding_factor=32):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
        pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
        if mode == 'sintel':
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        elif mode == 'top_right':
            self._pad = [0, pad_wd, pad_ht, 0]
        elif mode == 'bottom_right':
            self._pad = [0, pad_wd, 0, pad_ht]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def init_coords(ref):
    B, H, W, C = ref.shape

    coords = torch.meshgrid(torch.arange(H, device=ref.device, dtype=ref.dtype), torch.arange(W, device=ref.device, dtype=ref.dtype), indexing='ij')
    coords = torch.stack(coords[::-1], dim=-1)
    return coords[None].repeat(B, 1, 1, 1).to(ref.device)  # [B, H, W, 2]


def bilinear_sample_by_offset(tgt, offset):  # tgt [B, _, H, W], offset [B, H, W, 2]
    _, _, H, W = tgt.shape

    xgrid, ygrid = offset.split([1, 1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1
    grid = torch.cat([xgrid, ygrid], dim=-1)

    tgt_to_ref = F.grid_sample(tgt, grid, mode='bilinear', align_corners=True)
    return tgt_to_ref

def calc_noc_mask(field, A=2):
    offset = field + init_coords(field)  # [B, H, W, 2]
    field_ref_, field_tgt_ = field.chunk(2, dim=0)
    field_ref = torch.cat((field_ref_, field_tgt_), dim=0)  # forward order (ref, tgt)
    field_tgt = torch.cat((field_tgt_, field_ref_), dim=0)  # reversed order (tgt, ref)
    field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(), offset).permute(0, 2, 3, 1).contiguous()
    field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1)  # ref and tgt flows have opposite signs, so a consistent pair sums to ~0
    noc_mask = (field_diff < A).to(field_diff.dtype)
    return noc_mask
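A consistency-check sketch tying these helpers together (shapes are illustrative; assumes the repo root is on sys.path):

    import torch
    from utils.utils import InputPadder, calc_noc_mask

    img = torch.randn(1, 3, 375, 1242)
    padder = InputPadder(img.shape, padding_factor=32)
    padded, = padder.pad(img)                      # H and W now divisible by 32
    assert padder.unpad(padded).shape == img.shape

    # calc_noc_mask expects the forward and backward fields stacked on the
    # batch dimension: [2*B, H, W, 2]; a zero field is trivially consistent.
    field = torch.zeros(2, 64, 96, 2)
    noc = calc_noc_mask(field, A=2)                # all ones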