Tingman committed
Commit 0940df6 · 1 Parent(s): 9d66414

code release


Signed-off-by: tingmany <tmyann@outlook.com>

.gitattributes CHANGED
@@ -33,4 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- staircase_q_left.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
  colorTo: blue
  sdk: gradio
  sdk_version: 5.49.1
- app_file: app.py
+ app_file: gradio_app.py
  pinned: false
  license: gpl-3.0
  ---
dataloader/stereo/transforms.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import division
+ import torch
+ import numpy as np
+ import cv2
+
+
+ class Compose(object):
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, sample):
+         for t in self.transforms:
+             sample = t(sample)
+         return sample
+
+
+ class ToTensor(object):
+     """Convert numpy array to torch tensor"""
+
+     def __init__(self, no_normalize=False):
+         self.no_normalize = no_normalize
+
+     def __call__(self, sample):
+         left = np.transpose(sample['left'], (2, 0, 1))  # [3, H, W]
+         if self.no_normalize:
+             sample['left'] = torch.from_numpy(left)
+         else:
+             sample['left'] = torch.from_numpy(left) / 255.
+         right = np.transpose(sample['right'], (2, 0, 1))
+
+         if self.no_normalize:
+             sample['right'] = torch.from_numpy(right)
+         else:
+             sample['right'] = torch.from_numpy(right) / 255.
+
+         if 'disp' in sample.keys():
+             disp = sample['disp']  # [H, W]
+             sample['disp'] = torch.from_numpy(disp)
+         if 'disp_r' in sample.keys():
+             disp_r = sample['disp_r']  # [H, W]
+             sample['disp_r'] = torch.from_numpy(disp_r)
+
+         if 'valid' in sample.keys():
+             valid = sample['valid']  # [H, W]
+             sample['valid'] = torch.from_numpy(valid)
+
+         return sample
+
+
+ class Resize(object):
+     def __init__(self,
+                  scale_x=1,
+                  scale_y=1,
+                  nearest_interp=True,  # for sparse gt
+                  ):
+         """
+         Resize low-resolution data to high-res for mixed dataset training
+         """
+         self.scale_x = scale_x
+         self.scale_y = scale_y
+         self.nearest_interp = nearest_interp
+
+     def __call__(self, sample):
+         scale_x = self.scale_x
+         scale_y = self.scale_y
+
+         sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+         sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+
+         if 'disp' in sample.keys():
+             sample['disp'] = cv2.resize(
+                 sample['disp'], None, fx=scale_x, fy=scale_y,
+                 interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
+             ) * scale_x
+
+         if 'disp_r' in sample.keys():
+             sample['disp_r'] = cv2.resize(
+                 sample['disp_r'], None, fx=scale_x, fy=scale_y,
+                 interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
+             ) * scale_x
+
+         return sample
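For reference, a minimal usage sketch of these transforms on a stereo sample dict (the random arrays below are illustrative stand-ins for real images, not part of the release):

```python
import numpy as np
from dataloader.stereo import transforms

# Dummy stereo pair and disparity map, H x W x 3 float32 images as the transforms expect.
sample = {
    'left': np.random.rand(240, 320, 3).astype(np.float32) * 255,
    'right': np.random.rand(240, 320, 3).astype(np.float32) * 255,
    'disp': np.random.rand(240, 320).astype(np.float32) * 64,
}

# Upscale 2x (disparity values are rescaled by scale_x), then convert to tensors in [0, 1].
transform = transforms.Compose([
    transforms.Resize(scale_x=2, scale_y=2, nearest_interp=True),
    transforms.ToTensor(no_normalize=False),
])
sample = transform(sample)
print(sample['left'].shape, sample['disp'].shape)  # torch.Size([3, 480, 640]) torch.Size([480, 640])
```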
gradio_app.py ADDED
@@ -0,0 +1,371 @@
1
+ import gradio as gr
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import os
7
+ import time
8
+ import spaces
9
+
10
+ from dataloader.stereo import transforms
11
+ from utils.utils import InputPadder, calc_noc_mask
12
+ from huggingface_hub import hf_hub_download
13
+ from models.match_stereo import MatchStereo
14
+
15
+ torch.backends.cudnn.benchmark = True
16
+
17
+ class MatchStereoDemo:
18
+ def __init__(self):
19
+ self.has_cuda = torch.cuda.is_available()
20
+ self.device = torch.device('cuda:0') if self.has_cuda else 'cpu'
21
+ self.model = None
22
+ self.current_variant = None
23
+ self.current_mode = None
24
+ self.current_precision = None
25
+ self.current_mat_impl = None
26
+ self.download_model()
27
+
28
+ def download_model(self):
29
+ REPO_ID = 'Tingman/MatchAttention'
30
+ filename_list = ['matchstereo_tiny_fsd.pth', 'matchstereo_small_fsd.pth', 'matchstereo_base_fsd.pth', 'matchflow_base_sintel.pth']
31
+ if not os.path.exists('./checkpoints/'):
32
+ os.makedirs('./checkpoints/')
33
+ for filename in filename_list:
34
+ local_file = os.path.join('./checkpoints/', filename)
35
+ if not os.path.exists(local_file):
36
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/', local_dir_use_symlinks=False)
37
+
38
+ def load_model(self, mode, variant, precision, mat_impl):
39
+ """load model, skip if the model has been loaded"""
40
+ if (self.model is not None and
41
+ self.current_variant == variant and
42
+ self.current_mode == mode and
43
+ self.current_precision == precision and
44
+ self.current_mat_impl == mat_impl):
45
+ return "Model already loaded"
46
+
47
+ # fixed checkpoint path
48
+ checkpoint_base_path = "./checkpoints"
49
+ if mode == 'stereo':
50
+ checkpoint_name = f"match{mode}_{variant}_fsd.pth"
51
+ elif mode == 'flow':
52
+ checkpoint_name = f"match{mode}_{variant}_sintel.pth"
53
+ else:
54
+ raise NotImplementedError
55
+
56
+ checkpoint_path = os.path.join(checkpoint_base_path, checkpoint_name)
57
+
58
+ if not os.path.exists(checkpoint_path):
59
+ return f"Error: Checkpoint not found at {checkpoint_path}"
60
+
61
+ args = argparse.Namespace()
62
+ args.mode = mode
63
+ args.variant = variant
64
+ args.mat_impl = mat_impl
65
+
66
+ if not self.has_cuda:
67
+ precision = "fp32"
68
+ dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
69
+ self.dtype = dtypes[precision]
70
+
71
+ self.model = MatchStereo(args)
72
+
73
+ try:
74
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
75
+ self.model.load_state_dict(state_dict=checkpoint['model'], strict=False)
76
+ self.model.to(self.device)
77
+ self.model.eval()
78
+ self.model = self.model.to(self.dtype)
79
+
80
+ self._warmup_model()
81
+
82
+ self.current_variant = variant
83
+ self.current_mode = mode
84
+ self.current_precision = precision
85
+ self.current_mat_impl = mat_impl
86
+
87
+ device_info = "GPU" if self.has_cuda else "CPU"
88
+ return f"Successfully loaded {mode} {variant} model on {device_info} (precision: {precision}, mat_impl: {mat_impl})"
89
+ except Exception as e:
90
+ return f"Error loading model: {str(e)}"
91
+
92
+ def _warmup_model(self):
93
+ """warmup the model for accurate time measurement"""
94
+ if self.model is None:
95
+ return
96
+
97
+ dummy_left = torch.randn(1, 3, 512, 512, device=self.device, dtype=self.dtype)
98
+ dummy_right = torch.randn(1, 3, 512, 512, device=self.device, dtype=self.dtype)
99
+
100
+ with torch.no_grad():
101
+ _ = self.model(dummy_left, dummy_right, stereo=(self.current_mode == 'stereo'))
102
+
103
+ def run_frame(self, left, right, stereo, low_res_init=False, factor=2.):
104
+ """single frame inference"""
105
+ if low_res_init:
106
+ left_ds = F.interpolate(left, scale_factor=1/factor, mode='bilinear', align_corners=True)
107
+ right_ds = F.interpolate(right, scale_factor=1/factor, mode='bilinear', align_corners=True)
108
+ padder_ds = InputPadder(left_ds.shape, padding_factor=32)
109
+ left_ds, right_ds = padder_ds.pad(left_ds, right_ds)
110
+
111
+ field_up_ds = self.model(left_ds, right_ds, stereo=stereo)['field_up']
112
+ field_up_ds = padder_ds.unpad(field_up_ds.permute(0, 3, 1, 2).contiguous()).contiguous()
113
+ field_up_init = F.interpolate(field_up_ds, scale_factor=factor/32, mode='bilinear', align_corners=True)*(factor/32)
114
+ field_up_init = field_up_init.permute(0, 2, 3, 1).contiguous()
115
+ results_dict = self.model(left, right, stereo=stereo, init_flow=field_up_init)
116
+ else:
117
+ results_dict = self.model(left, right, stereo=stereo)
118
+
119
+ return results_dict
120
+
121
+ def get_inference_size(self, size_name):
122
+ if size_name == "Original":
123
+ return None
124
+
125
+ def round_to_32(x):
126
+ return (x + 16) // 32 * 32
127
+
128
+ size_presets = {
129
+ "720P": (round_to_32(1280), round_to_32(720)),
130
+ "1080P": (round_to_32(1920), round_to_32(1080)),
131
+ "2K": (round_to_32(2048), round_to_32(1080)),
132
+ "4K UHD": (round_to_32(3840), round_to_32(2160))
133
+ }
134
+
135
+ return size_presets.get(size_name, None)
136
+
137
+ def process_images(self, left_image, right_image, mode, variant,
138
+ low_res_init=False, inference_size_name="Original",
139
+ precision="fp32", mat_impl="pytorch"):
140
+ if not self.has_cuda:
141
+ precision = "fp32"
142
+ mat_impl = "pytorch"
143
+
144
+ load_result = self.load_model(mode, variant, precision, mat_impl)
145
+ if load_result.startswith("Error"):
146
+ return None, None, None, load_result
147
+
148
+ try:
149
+ left = np.array(left_image.convert('RGB')).astype(np.float32)
150
+ right = np.array(right_image.convert('RGB')).astype(np.float32)
151
+
152
+ original_size = left.shape[:2] # (H, W)
153
+
154
+ inference_size = self.get_inference_size(inference_size_name)
155
+
156
+ val_transform_list = [transforms.ToTensor(no_normalize=True)]
157
+ val_transform = transforms.Compose(val_transform_list)
158
+
159
+ sample = {'left': left, 'right': right}
160
+ sample = val_transform(sample)
161
+ left_tensor = sample['left'].to(self.device, dtype=self.dtype).unsqueeze(0)
162
+ right_tensor = sample['right'].to(self.device, dtype=self.dtype).unsqueeze(0)
163
+
164
+ stereo = (mode == 'stereo')
165
+
166
+ ori_size = left_tensor.shape[-2:]
167
+ if inference_size is not None:
168
+ left_tensor = F.interpolate(left_tensor, size=inference_size, mode='bilinear', align_corners=True)
169
+ right_tensor = F.interpolate(right_tensor, size=inference_size, mode='bilinear', align_corners=True)
170
+ padder = None
171
+ else:
172
+ padder = InputPadder(left_tensor.shape, padding_factor=32)
173
+ left_tensor, right_tensor = padder.pad(left_tensor, right_tensor)
174
+
175
+ device_type = "GPU" if self.has_cuda else "CPU"
176
+ actual_size = inference_size if inference_size else ori_size
177
+ status_info = f"Device: {device_type} | Resolution: {actual_size[1]}x{actual_size[0]} | Precision: {precision}"
178
+
179
+ start_time = time.time()
180
+ with torch.no_grad():
181
+ results_dict = self.run_frame(left_tensor, right_tensor, stereo, low_res_init)
182
+ inference_time = (time.time() - start_time) * 1000 # ms
183
+
184
+ field_up = results_dict['field_up'].permute(0, 3, 1, 2).float().contiguous()
185
+
186
+ if padder is not None:
187
+ field_up = padder.unpad(field_up)
188
+ elif inference_size is not None:
189
+ field_up = F.interpolate(field_up, size=ori_size, mode='bilinear', align_corners=True)
190
+ field_up[:, 0] = field_up[:, 0] * (ori_size[1] / float(inference_size[1]))
191
+ field_up[:, 1] = field_up[:, 1] * (ori_size[0] / float(inference_size[0]))
192
+
193
+ noc_mask = calc_noc_mask(field_up.permute(0, 2, 3, 1), A=8)
194
+ noc_mask = noc_mask[0].detach().cpu().numpy()
195
+ noc_mask = np.where(noc_mask, 255, 128).astype(np.uint8)
196
+
197
+ field_up = torch.cat((field_up, torch.zeros_like(field_up[:, :1])), dim=1)
198
+ field_up = field_up.permute(0, 2, 3, 1).contiguous()
199
+ field, field_r = field_up.chunk(2, dim=0)
200
+
201
+ if stereo:
202
+ disparity = (-field[..., 0]).clamp(min=0)
203
+
204
+ disparity_np = disparity[0].detach().cpu().numpy()
205
+ min_val = disparity_np.min()
206
+ max_val = disparity_np.max()
207
+ if max_val - min_val > 1e-6:
208
+ disparity_norm = (disparity_np - min_val) / (max_val - min_val)
209
+ else:
210
+ disparity_norm = np.zeros_like(disparity_np)
211
+ disparity_img = (disparity_norm * 255).astype(np.uint8)
212
+
213
+ return disparity_img, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
214
+ else:
215
+ flow = field[0].detach().cpu().numpy()
216
+ flow_rgb = self.flow_to_color(flow)
217
+ return flow_rgb, noc_mask, f"Inference time: {inference_time:.2f} ms. (Please re-run to get accurate time.)", status_info
218
+
219
+ except Exception as e:
220
+ device_type = "GPU" if self.has_cuda else "CPU"
221
+ return None, None, f"Error during inference: {str(e)}", f"Device: {device_type} | Error occurred"
222
+
223
+ def flow_to_color(self, flow):
224
+ """visualization of flow"""
225
+ u = flow[..., 0]
226
+ v = flow[..., 1]
227
+
228
+ rad = np.sqrt(u**2 + v**2)
229
+ rad_max = np.max(rad)
230
+ epsilon = 1e-8
231
+
232
+ if rad_max > epsilon:
233
+ u = u / (rad_max + epsilon)
234
+ v = v / (rad_max + epsilon)
235
+
236
+ h, w = u.shape
237
+ hsv = np.zeros((h, w, 3), dtype=np.uint8)
238
+ hsv[..., 1] = 255
239
+
240
+ mag, ang = cv2.cartToPolar(u, v)
241
+ hsv[..., 0] = ang * 180 / np.pi / 2
242
+ hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
243
+
244
+ flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
245
+ return flow_rgb
246
+
247
+ demo_model = MatchStereoDemo()
248
+
249
+ # example images
250
+ examples = [
251
+ ["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
252
+ ["examples/staircase_q_left.png", "examples/staircase_q_right.png", "stereo", "tiny"],
253
+ ["examples/frame_0031_clean.png", "examples/frame_0032_clean.png", "flow", "base"],
254
+ ]
255
+
256
+ @spaces.GPU
257
+ def process_inference(left_img, right_img, mode, variant,
258
+ low_res_init, inference_size, precision, mat_impl):
259
+ """Gradio function"""
260
+ if left_img is None or right_img is None:
261
+ return None, None, "Please upload both left and right images", "Waiting for input..."
262
+
263
+ try:
264
+ result = demo_model.process_images(
265
+ left_img, right_img, mode, variant,
266
+ low_res_init, inference_size, precision, mat_impl
267
+ )
268
+ return result
269
+ except Exception as e:
270
+ return None, None, f"Error during inference: {str(e)}", f"Error: {str(e)}"
271
+
272
+ def update_variant_choices(mode):
273
+ if mode == "flow":
274
+ return gr.Radio(choices=["base"], value="base")
275
+ else:
276
+ return gr.Radio(choices=["tiny", "small", "base"], value="tiny")
277
+
278
+ # Gradio UI
279
+ with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
280
+ gr.Markdown("# MatchStereo/MatchFlow Demo")
281
+ gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
282
+
283
+ if not demo_model.has_cuda:
284
+ gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
285
+
286
+ with gr.Row():
287
+ with gr.Column():
288
+ left_image = gr.Image(label="Left Image / Frame 1", type="pil")
289
+ right_image = gr.Image(label="Right Image / Frame 2", type="pil")
290
+
291
+ with gr.Row():
292
+ mode = gr.Radio(
293
+ choices=["stereo", "flow"],
294
+ label="Mode",
295
+ value="stereo",
296
+ info="Select stereo for disparity estimation or flow for optical flow"
297
+ )
298
+ variant = gr.Radio(
299
+ choices=["tiny", "small", "base"],
300
+ label="Model Variant",
301
+ value="tiny",
302
+ info="Model size variant"
303
+ )
304
+
305
+ with gr.Row():
306
+ low_res_init = gr.Checkbox(
307
+ label="Low Resolution Init",
308
+ value=False,
309
+ info="Use low-resolution initialization for high-res images (>=2K)"
310
+ )
311
+ inference_size = gr.Dropdown(
312
+ choices=["Original", "720P", "1080P", "2K", "4K UHD"],
313
+ label="Inference Size",
314
+ value="Original",
315
+ info="Rounded to multiples of 32"
316
+ )
317
+
318
+ with gr.Row():
319
+ precision = gr.Radio(
320
+ choices=["fp32", "fp16"],
321
+ label="Precision",
322
+ value="fp32",
323
+ info="Model precision",
324
+ interactive=demo_model.has_cuda
325
+ )
326
+ mat_impl = gr.Radio(
327
+ choices=["cuda", "pytorch"],
328
+ label="MatchAttention Implementation",
329
+ value="cuda",
330
+ info="MatchAttention implementations",
331
+ interactive=demo_model.has_cuda
332
+ )
333
+
334
+ run_btn = gr.Button("Run Inference", variant="primary")
335
+
336
+ with gr.Column():
337
+ output_image = gr.Image(label="Output Result", interactive=False)
338
+ noc_mask = gr.Image(label="NOC Mask", interactive=False)
339
+ time_output = gr.Textbox(label="Inference Time", interactive=False)
340
+ status = gr.Textbox(label="Status Info", interactive=False, lines=2)
341
+
342
+ gr.Markdown("## Examples")
343
+ gr.Examples(
344
+ examples=examples,
345
+ inputs=[left_image, right_image, mode, variant],
346
+ outputs=[output_image, noc_mask, time_output, status],
347
+ fn=process_inference,
348
+ cache_examples=False,
349
+ label="Click any example below to load it"
350
+ )
351
+
352
+ run_btn.click(
353
+ fn=process_inference,
354
+ inputs=[left_image, right_image, mode, variant,
355
+ low_res_init, inference_size, precision, mat_impl],
356
+ outputs=[output_image, noc_mask, time_output, status]
357
+ )
358
+
359
+ mode.change(
360
+ fn=update_variant_choices,
361
+ inputs=[mode],
362
+ outputs=[variant]
363
+ )
364
+
365
+ if __name__ == "__main__":
366
+ try:
367
+ import cv2
368
+ except ImportError:
369
+ print("Please install OpenCV for optical flow visualization: pip install opencv-python")
370
+
371
+ demo.launch()
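The resolution handling in `process_images` uses the standard rule that a flow/disparity field resized to another resolution must also have its x and y components multiplied by the respective width and height ratios. A standalone sketch of that step (the function name and dummy sizes are illustrative only):

```python
import torch
import torch.nn.functional as F

def rescale_field(field: torch.Tensor, out_hw: tuple) -> torch.Tensor:
    """Resize a [B, 2, h, w] flow/disparity field to out_hw=(H, W) and rescale its values."""
    h, w = field.shape[-2:]
    H, W = out_hw
    out = F.interpolate(field, size=(H, W), mode='bilinear', align_corners=True)
    out[:, 0] = out[:, 0] * (W / float(w))  # x-displacements scale with width
    out[:, 1] = out[:, 1] * (H / float(h))  # y-displacements scale with height
    return out

field = torch.randn(1, 2, 544, 960)              # e.g. a prediction at the inference size
print(rescale_field(field, (1080, 1920)).shape)  # torch.Size([1, 2, 1080, 1920])
```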
models/__init__.py ADDED
File without changes
models/attention_blocks.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from timm.models.layers import DropPath
5
+
6
+ from models.convformer import LayerNormWithoutBias
7
+ from models.common import ConvGLU
8
+ from models.mat_pytorch_impl import compute_bilinear_weights, compute_match_attention, compute_bilinear_softmax, attention_aggregate
9
+ from models.match_former_ops import MF_FusedForwardOps
10
+ from utils.utils import bilinear_sample_by_offset, init_coords
11
+
12
+ class MatchAttention(torch.nn.Module):
13
+ r"""MatchAttention: Matching the relative positions
14
+ """
15
+ def __init__(self, args, dim, win_r=[1, 1], num_head=8, head_dim=None, qkv_bias=False,
16
+ attn_drop=0., proj_drop=0., proj_bias=False, cross=False, noc_embed=False, **kargs):
17
+ super().__init__()
18
+
19
+ self.num_head = num_head
20
+ self.cross = cross
21
+ self.noc_embed = noc_embed if not cross else False # only for self attention
22
+
23
+ self.head_dim = dim // num_head if head_dim is None else head_dim
24
+ self.scale = self.head_dim ** -0.5
25
+
26
+ self.attention_dim = self.num_head * self.head_dim
27
+
28
+ self.win_r = win_r
29
+ self.attn_num = (2*win_r[0]+2)*(2*win_r[1]+2)
30
+
31
+ embed_dim = dim + 1 if noc_embed else dim # '1' for noc_mask
32
+ self.q = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
33
+ self.k = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
34
+ self.v = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
35
+ self.attn_drop = nn.Dropout(attn_drop)
36
+ if self.cross:
37
+ self.g = nn.Sequential(nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias), nn.SiLU())
38
+ self.proj = nn.Linear(self.attention_dim + self.num_head*self.attn_num, dim, bias=proj_bias)
39
+ else:
40
+ self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
41
+ self.proj_drop = nn.Dropout(proj_drop)
42
+ self.use_pytorch = (args.mat_impl == 'pytorch')
43
+ self.mf_fused = MF_FusedForwardOps()
44
+
45
+ def clamp_max_offset(self, max_offset, H, W):
46
+ max_offset_x, max_offset_y = max_offset.chunk(2, dim=-1) # to avoid inplace operation
47
+
48
+ # for ONNX support
49
+ min_x = torch.tensor(self.win_r[0], dtype=max_offset.dtype, device=max_offset.device)
50
+ max_x = torch.tensor(W - 1 - self.win_r[0] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
51
+ min_y = torch.tensor(self.win_r[1], dtype=max_offset.dtype, device=max_offset.device)
52
+ max_y = torch.tensor(H - 1 - self.win_r[1] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
53
+
54
+ max_offset_x = torch.clamp(max_offset_x, min=min_x, max=max_x)
55
+ max_offset_y = torch.clamp(max_offset_y, min=min_y, max=max_y)
56
+
57
+ ## max_offset_x = max_offset_x.clamp(min=self.win_r[0], max=W-1-self.win_r[0]-1e-3)
58
+ ## max_offset_y = max_offset_y.clamp(min=self.win_r[1], max=H-1-self.win_r[1]-1e-3)
59
+ return torch.cat((max_offset_x, max_offset_y), dim=-1).contiguous()
60
+
61
+ def forward(self, x, max_offset, noc_mask=None): # offset: [B, N, h, 2]
62
+ B, H, W, _ = x.shape
63
+ N = H*W
64
+ assert (2*self.win_r[1] + 2 <= H) and (2*self.win_r[0] + 2 <= W)
65
+ x = x.view(B, N, -1).contiguous()
66
+
67
+ if self.cross:
68
+ ref_, tgt_ = x.chunk(2, dim=0) # split along batch dimension
69
+ ref = torch.cat((ref_, tgt_), dim=0) # order
70
+ tgt = torch.cat((tgt_, ref_), dim=0) # reverse order
71
+ g = self.g(ref)
72
+ else: # self-attn
73
+ if self.noc_embed:
74
+ x = torch.cat((x, noc_mask.view(B, N, -1)), dim=-1).contiguous()
75
+ ref, tgt = x, x
76
+ q, k, v = self.q(ref), self.k(tgt), self.v(tgt)
77
+
78
+ ## non-parameter modules
79
+ max_offset = self.clamp_max_offset(max_offset, H, W)
80
+
81
+ if self.use_pytorch:
82
+ m_id = torch.floor(max_offset).to(torch.int32) # [B, N, h, 2]
83
+ bilinear_weight = compute_bilinear_weights(max_offset)
84
+
85
+ attn, indices_gather = compute_match_attention(q.view(B, N, self.num_head, -1), k.view(B, N, self.num_head, -1), m_id, self.win_r, H, W)
86
+ attn = attn * self.scale
87
+
88
+ attn = compute_bilinear_softmax(attn, bilinear_weight, self.win_r)
89
+ attn = self.attn_drop(attn)
90
+
91
+ x = attention_aggregate(v.view(B, N, self.num_head, -1), attn, indices_gather, self.win_r)
92
+ else:
93
+ x, attn = self.mf_fused(max_offset, q, k, v, H, W, self.win_r, self.attn_num, attn_type='l1_norm', scale=self.scale)
94
+
95
+ if self.cross:
96
+ x = g * x # gate
97
+ attn = attn.view(B, N, -1).contiguous()
98
+ x = torch.cat((x, attn), dim=-1).contiguous()
99
+ x = self.proj(x)
100
+ x = self.proj_drop(x)
101
+ return x.view(B, H, W, -1).contiguous()
102
+
103
+
104
+ class MatchAttentionLayer(nn.Module):
105
+ r"""MatchAttention layer with interleaved self-MatchAttention, cross-MatchAttention, and ConvGLU
106
+ """
107
+
108
+ def __init__(self, args, dim, win_r,
109
+ num_head=8, head_dim=32, mlp=ConvGLU, mlp_ratio=2, field_dim=2,
110
+ norm_layer=nn.LayerNorm, drop=0., drop_path=0.):
111
+ super().__init__()
112
+ self.num_head = num_head
113
+ self.field_dim = field_dim
114
+
115
+ self.match_attention_self = MatchAttention(args, dim + self.field_dim + self.num_head*2, [win_r, win_r], num_head=num_head, head_dim=head_dim, noc_embed=True)
116
+ self.norm0 = norm_layer(dim + self.field_dim + self.num_head*2)
117
+
118
+ self.match_attention_cross = MatchAttention(args, dim + self.field_dim, [win_r, win_r], num_head=num_head, head_dim=head_dim, cross=True)
119
+ self.norm1 = norm_layer(dim + self.field_dim)
120
+
121
+ self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
122
+ self.norm2 = norm_layer(dim)
123
+
124
+ self.field_scale = nn.Parameter(0.1*torch.ones(1, 1, 1, 2))
125
+
126
+ self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
127
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
128
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
129
+
130
+ def consistency_mask(self, field, A=2):
131
+ offset = field + init_coords(field) # [B, H, W, 2]
132
+ field_ref_, field_tgt_ = field.chunk(2, dim=0)
133
+ field_ref = torch.cat((field_ref_, field_tgt_), dim=0) # order
134
+ field_tgt = torch.cat((field_tgt_, field_ref_), dim=0) # reverse order
135
+ field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(), offset).permute(0, 2, 3, 1).contiguous()
136
+ field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1, keepdim=True) # ref and tgt flows have opposite signs, so a consistent pair sums to ~0
137
+ noc_mask = (field_diff < A).to(field_diff.dtype)
138
+ return noc_mask
139
+
140
+ def forward(self, x, self_rpos, field, stereo=True): # self_rpos [B, H, W, h*2], field [B, H, W, 2]
141
+
142
+ field_out = {}
143
+ B, H, W, C = x.shape
144
+
145
+ noc_mask = self.consistency_mask(field.detach())
146
+
147
+ x = torch.cat((x, field*self.field_scale.to(field.dtype), self_rpos), dim=-1).contiguous()
148
+
149
+ coords_0 = init_coords(field).repeat(1, 1, 1, self.num_head)
150
+ self_offset = self_rpos + coords_0
151
+ self_offset = self_offset.view(B, H*W, self.num_head, 2).contiguous()
152
+
153
+ x = x + self.drop_path0(self.match_attention_self(self.norm0(x), self_offset, noc_mask))
154
+
155
+ self_rpos = x[..., -(self.num_head*2):].contiguous() # [B, H, W, h*2]
156
+ x = x[..., :-(self.num_head*2)].contiguous()
157
+
158
+ if stereo: x[..., -1] = 0
159
+ field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
160
+ field_out['self'] = field.clone()
161
+
162
+ offset = field.repeat(1, 1, 1, self.num_head).contiguous() + coords_0 # [B, H, W, h*2]
163
+ offset = offset.view(B, H*W, self.num_head, 2).contiguous()
164
+
165
+ x = x + self.drop_path1(self.match_attention_cross(self.norm1(x), offset))
166
+
167
+ if stereo: x[..., -1] = 0
168
+ field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
169
+ field_out['cross'] = field.clone()
170
+
171
+ x = x[..., :-self.field_dim].contiguous() # No field feature in MLP
172
+
173
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
174
+
175
+ return x, self_rpos, field, field_out
176
+
177
+
178
+ class MatchAttentionBlock(nn.Module):
179
+ r"""MatchAttention block with multiple match-attention layers
180
+ """
181
+
182
+ def __init__(self, args, dim, win_r=2,
183
+ num_layer=6, num_head=8, head_dim=32,
184
+ mlp=ConvGLU, mlp_ratio=2, field_dim=2,
185
+ norm_layer=LayerNormWithoutBias,
186
+ drop=0., dp_rates=[0.]):
187
+
188
+ super().__init__()
189
+ self.num_head = num_head
190
+
191
+ self.layers = nn.ModuleList()
192
+ for i in range(num_layer):
193
+ layer = MatchAttentionLayer(args, dim, win_r=win_r, num_head=num_head, head_dim=head_dim,
194
+ mlp=mlp, mlp_ratio=mlp_ratio, field_dim=field_dim,
195
+ norm_layer=norm_layer, drop=drop, drop_path=dp_rates[i])
196
+ self.layers.append(layer)
197
+
198
+ def forward(self, x, self_rpos, field, stereo=True):
199
+ fields = []
200
+ B, H, W, C = x.shape
201
+ self_rpos = self_rpos.repeat(1, 1, 1, self.num_head) # [B, H, W, 2] -> [B, H, W, h*2]
202
+
203
+ for layer in self.layers:
204
+
205
+ x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo)
206
+ fields.append(field_out)
207
+
208
+ self_rpos = self_rpos.view(B, H, W, self.num_head, 2).mean(dim=-2, keepdim=False)
209
+
210
+ return x, self_rpos, field, fields
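`consistency_mask` applies the usual forward-backward check: a pixel is treated as non-occluded when the field warped back from the other view roughly cancels the reference field. Below is a simplified numpy sketch of the same criterion for the stereo (1-D disparity) case, written only to illustrate the idea; it is not the implementation used above:

```python
import numpy as np

def lr_consistency_mask(disp_left: np.ndarray, disp_right: np.ndarray, thresh: float = 2.0) -> np.ndarray:
    """Boolean [H, W] mask: True where left disparity agrees with the right disparity
    sampled at the matched column (nearest-neighbour lookup)."""
    H, W = disp_left.shape
    xs = np.arange(W)[None, :].repeat(H, axis=0)                          # x coordinate of every left pixel
    matched_x = np.clip(np.round(xs - disp_left).astype(int), 0, W - 1)   # corresponding column in the right view
    disp_right_warped = np.take_along_axis(disp_right, matched_x, axis=1)
    return np.abs(disp_left - disp_right_warped) < thresh

disp_l = np.full((4, 8), 2.0)
disp_r = np.full((4, 8), 2.0)
print(lr_consistency_mask(disp_l, disp_r).all())  # True: perfectly consistent disparities
```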
models/common.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class UpConv(nn.Module):
+     r"""Upsample using transposed conv"""
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
+         self.conv = nn.Sequential(
+             nn.Conv2d(out_channels*2, out_channels, kernel_size=1, padding=0),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x1, x2, use_up=True):
+         x1 = x1.permute(0, 3, 1, 2).contiguous()
+         x2 = x2.permute(0, 3, 1, 2).contiguous()
+         if use_up:
+             x1 = self.up(x1)
+         x = torch.cat([x2, x1], dim=1)
+         out = self.conv(x)
+         return out.permute(0, 2, 3, 1).contiguous()  # [B, H, W, C]
+
+ class ConvGLU(nn.Module):
+     '''
+     Convolutional GLU, referenced from TransNeXt
+     '''
+     def __init__(self, dim, mlp_ratio=2, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         in_features = dim
+         out_features = out_features or in_features
+         hidden_features = int(mlp_ratio * in_features)
+         self.fc1 = nn.Linear(in_features, hidden_features * 2)
+         self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, bias=True, groups=hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):  # [B, H, W, C]
+         x, v = self.fc1(x).chunk(2, dim=-1)
+         x = self.act(self.dwconv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()) * v
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
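`ConvGLU` operates on channels-last tensors and preserves the spatial size; a quick shape check (the dummy tensor sizes are arbitrary):

```python
import torch
from models.common import ConvGLU

glu = ConvGLU(dim=64, mlp_ratio=2)
x = torch.randn(2, 32, 48, 64)   # [B, H, W, C], channels last as used throughout the models
print(glu(x).shape)              # torch.Size([2, 32, 48, 64])
```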
models/compile.sh ADDED
@@ -0,0 +1,4 @@
+ #!/bin/bash
+ rm -rf build/ dist/ match_attention.egg-info/ __pycache__
+ python setup.py clean
+ pip install .
models/convformer.py ADDED
@@ -0,0 +1,391 @@
1
+ from functools import partial
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from timm.models.layers import trunc_normal_, DropPath
7
+ from timm.models.registry import register_model
8
+ from timm.models.layers.helpers import to_2tuple
9
+ class LayerNormGeneral(nn.Module):
10
+ r""" General LayerNorm for different situations.
11
+
12
+ Args:
13
+ affine_shape (int, list or tuple): The shape of affine weight and bias.
14
+ Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm,
15
+ the affine_shape is the same as normalized_dim by default.
16
+ To adapt to different situations, we offer this argument here.
17
+ normalized_dim (tuple or list): Which dims to compute mean and variance.
18
+ scale (bool): Flag indicates whether to use scale or not.
19
+ bias (bool): Flag indicates whether to use scale or not.
20
+
21
+ We give several examples to show how to specify the arguments.
22
+
23
+ LayerNorm (https://arxiv.org/abs/1607.06450):
24
+ For input shape of (B, *, C) like (B, N, C) or (B, H, W, C),
25
+ affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True;
26
+ For input shape of (B, C, H, W),
27
+ affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True.
28
+
29
+ Modified LayerNorm (https://arxiv.org/abs/2111.11418)
30
+ that is identical to partial(torch.nn.GroupNorm, num_groups=1):
31
+ For input shape of (B, N, C),
32
+ affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True;
33
+ For input shape of (B, H, W, C),
34
+ affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True;
35
+ For input shape of (B, C, H, W),
36
+ affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True.
37
+
38
+ For the several MetaFormer baselines,
39
+ IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False);
40
+ ConvFormer and CAFormer utilize LayerNorm without bias (bias=False).
41
+ """
42
+
43
+ def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
44
+ bias=False, eps=1e-6):
45
+ super().__init__()
46
+ self.normalized_dim = normalized_dim
47
+ self.use_scale = scale
48
+ self.use_bias = bias
49
+ self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
50
+ self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
51
+ self.eps = eps
52
+
53
+ def forward(self, x):
54
+ c = x - x.mean(self.normalized_dim, keepdim=True)
55
+ s = c.pow(2).mean(self.normalized_dim, keepdim=True)
56
+ x = c / torch.sqrt(s + self.eps)
57
+ if self.use_scale:
58
+ x = x * self.weight
59
+ if self.use_bias:
60
+ x = x + self.bias
61
+ return x
62
+
63
+
64
+
65
+ def stem(in_chs, out_chs, act_layer=nn.GELU):
66
+ return nn.Sequential(
67
+ nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
68
+ ## nn.BatchNorm2d(out_chs // 2),
69
+ nn.InstanceNorm2d(out_chs // 2),
70
+ act_layer(),
71
+ nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
72
+ ## nn.BatchNorm2d(out_chs),
73
+ nn.InstanceNorm2d(out_chs),
74
+ act_layer(),
75
+ )
76
+
77
+ class Downsampling(nn.Module):
78
+ """
79
+ Downsampling implemented by a layer of convolution.
80
+ """
81
+
82
+ def __init__(self, in_channels, out_channels,
83
+ kernel_size=3, stride=2, padding=1,
84
+ pre_norm=LayerNormGeneral, post_norm=None, pre_permute=True):
85
+ super().__init__()
86
+ self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
87
+ self.pre_permute = pre_permute
88
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
89
+ stride=stride, padding=padding)
90
+ self.post_norm = post_norm(
91
+ out_channels) if post_norm else nn.Identity()
92
+
93
+ def forward(self, x):
94
+ x = self.pre_norm(x)
95
+ if self.pre_permute:
96
+ x = x.permute(0, 3, 1, 2).contiguous() # if take [B, H, W, C] as input, permute it to [B, C, H, W]
97
+ x = self.conv(x)
98
+ x = x.permute(0, 2, 3, 1).contiguous() # [B, C, H, W] -> [B, H, W, C]
99
+ x = self.post_norm(x)
100
+ return x
101
+
102
+
103
+ class Scale(nn.Module):
104
+ """
105
+ Scale vector by element multiplications.
106
+ """
107
+
108
+ def __init__(self, dim, init_value=1.0, trainable=True):
109
+ super().__init__()
110
+ self.scale = nn.Parameter(
111
+ init_value * torch.ones(dim), requires_grad=trainable)
112
+
113
+ def forward(self, x):
114
+ return x * self.scale
115
+
116
+
117
+ class LayerNormWithoutBias(nn.Module):
118
+ """
119
+ Equal to partial(LayerNormGeneral, bias=False) but faster,
120
+ because it directly utilizes the optimized F.layer_norm
121
+ """
122
+
123
+ def __init__(self, normalized_shape, eps=1e-5, **kwargs):
124
+ super().__init__()
125
+ self.eps = eps
126
+ self.bias = None
127
+ if isinstance(normalized_shape, int):
128
+ normalized_shape = (normalized_shape,)
129
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
130
+ self.normalized_shape = normalized_shape
131
+
132
+ def forward(self, x):
133
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
134
+
135
+
136
+ class SepConv(nn.Module):
137
+ r"""
138
+ Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
139
+ """
140
+
141
+ def __init__(self, dim, expansion_ratio=2,
142
+ act1_layer=nn.GELU, act2_layer=nn.Identity,
143
+ bias=False, kernel_size=3, padding=1,
144
+ **kwargs, ):
145
+ super().__init__()
146
+ med_channels = int(expansion_ratio * dim)
147
+ self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
148
+ self.act1 = act1_layer()
149
+ self.dwconv = nn.Conv2d(
150
+ med_channels, med_channels, kernel_size=kernel_size,
151
+ padding=padding, groups=med_channels, bias=bias) # depthwise conv
152
+ self.act2 = act2_layer()
153
+ self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)
154
+
155
+ def forward(self, x):
156
+ x = self.pwconv1(x)
157
+ x = self.act1(x)
158
+ x = x.permute(0, 3, 1, 2)
159
+ x = self.dwconv(x)
160
+ x = x.permute(0, 2, 3, 1)
161
+ x = self.act2(x)
162
+ x = self.pwconv2(x)
163
+ return x
164
+
165
+ class Mlp(nn.Module):
166
+ """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
167
+ Mostly copied from timm.
168
+ """
169
+
170
+ def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=nn.GELU, drop=0., bias=False, **kwargs):
171
+ super().__init__()
172
+ in_features = dim
173
+ out_features = out_features or in_features
174
+ hidden_features = int(mlp_ratio * in_features)
175
+ drop_probs = to_2tuple(drop)
176
+
177
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
178
+ self.act = act_layer()
179
+ self.drop1 = nn.Dropout(drop_probs[0])
180
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
181
+ self.drop2 = nn.Dropout(drop_probs[1])
182
+
183
+ def forward(self, x):
184
+ x = self.fc1(x)
185
+ x = self.act(x)
186
+ x = self.drop1(x)
187
+ x = self.fc2(x)
188
+ x = self.drop2(x)
189
+ return x
190
+
191
+
192
+ class MetaFormerBlock(nn.Module):
193
+ """
194
+ Implementation of one MetaFormer block.
195
+ """
196
+
197
+ def __init__(self, dim,
198
+ token_mixer=nn.Identity, mlp=Mlp, mlp_ratio=4,
199
+ norm_layer=nn.LayerNorm, drop=0., drop_path=0.,
200
+ layer_scale_init_value=None, res_scale_init_value=None
201
+ ):
202
+
203
+ super().__init__()
204
+
205
+ self.token_mixer = token_mixer(dim, drop=drop)
206
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
207
+ self.norm1 = norm_layer(dim)
208
+ self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \
209
+ if layer_scale_init_value else nn.Identity()
210
+ self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \
211
+ if res_scale_init_value else nn.Identity()
212
+
213
+ self.norm2 = norm_layer(dim)
214
+ self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
215
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
216
+ self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \
217
+ if layer_scale_init_value else nn.Identity()
218
+ self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \
219
+ if res_scale_init_value else nn.Identity()
220
+
221
+ def forward(self, x):
222
+ x = x + self.drop_path1(self.token_mixer(self.norm1(x)))
223
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
224
+ return x
225
+
226
+
227
+ class MetaFormer(nn.Module):
228
+ r""" MetaFormer
229
+ A PyTorch impl of : `MetaFormer Baselines for Vision` -
230
+ https://arxiv.org/abs/2210.13452
231
+
232
+ Args:
233
+ in_chans (int): Number of input image channels. Default: 3.
234
+ num_classes (int): Number of classes for classification head. Default: 1000.
235
+ depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
236
+ dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512].
237
+ downsample_layers: (list or tuple): Downsampling layers before each stage.
238
+ token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
239
+ mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
240
+ norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False).
241
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
242
+ layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None.
243
+ None means not to use the layer scale. From: https://arxiv.org/abs/2103.17239.
244
+ res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0].
245
+ None means not to use the layer scale. From: https://arxiv.org/abs/2110.09456.
246
+ head_fn: classification head. Default: nn.Linear.
247
+ """
248
+
249
+ def __init__(self, in_chans=3, num_classes=1000,
250
+ depths=[2, 2, 6, 2],
251
+ dims=[64, 128, 320, 512],
252
+ downsample_layers=[stem] + [Downsampling]*3,
253
+ token_mixers=nn.Identity,
254
+ mlps=Mlp, mlp_ratio=4,
255
+ norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
256
+ drop_path_rate=0.,
257
+ layer_scale_init_values=None,
258
+ res_scale_init_values=[None, None, 1.0, 1.0],
259
+ head_fn=nn.Linear,
260
+ **kwargs,
261
+ ):
262
+ super().__init__()
263
+ self.num_classes = num_classes
264
+
265
+ if not isinstance(depths, (list, tuple)):
266
+ depths = [depths] # it means the model has only one stage
267
+ if not isinstance(dims, (list, tuple)):
268
+ dims = [dims]
269
+
270
+ self.dims = dims
271
+ self.depths = depths
272
+
273
+ num_stage = len(depths)
274
+ self.num_stage = num_stage
275
+
276
+ down_dims = [in_chans] + dims
277
+ self.downsample_layers = nn.ModuleList([downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)])
278
+
279
+ if not isinstance(token_mixers, (list, tuple)):
280
+ token_mixers = [token_mixers] * num_stage
281
+ self.token_mixers = token_mixers
282
+
283
+ if not isinstance(mlps, (list, tuple)):
284
+ mlps = [mlps] * num_stage
285
+
286
+ if not isinstance(norm_layers, (list, tuple)):
287
+ norm_layers = [norm_layers] * num_stage
288
+
289
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
290
+
291
+ if not isinstance(layer_scale_init_values, (list, tuple)):
292
+ layer_scale_init_values = [layer_scale_init_values] * num_stage
293
+ if not isinstance(res_scale_init_values, (list, tuple)):
294
+ res_scale_init_values = [res_scale_init_values] * num_stage
295
+
296
+ self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
297
+ cur = 0
298
+ for i in range(num_stage):
299
+ stage = nn.ModuleList(
300
+ [MetaFormerBlock(dim=dims[i], token_mixer=token_mixers[i],
301
+ mlp=mlps[i], mlp_ratio=mlp_ratio, norm_layer=norm_layers[i],
302
+ drop_path=dp_rates[cur + j],
303
+ layer_scale_init_value=layer_scale_init_values[i],
304
+ res_scale_init_value=res_scale_init_values[i],
305
+ ) for j in range(depths[i])]
306
+ )
307
+ self.stages.append(stage)
308
+ cur += depths[i]
309
+
310
+ self.head = head_fn(dims[-1], num_classes)
311
+
312
+ self.apply(self._init_weights)
313
+
314
+ def _init_weights(self, m):
315
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
316
+ trunc_normal_(m.weight, std=.02)
317
+ if m.bias is not None:
318
+ nn.init.constant_(m.bias, 0)
319
+
320
+
321
+ def forward(self, x):
322
+ outs = []
323
+ for i in range(self.num_stage):
324
+ x = self.downsample_layers[i](x)
325
+ if i==0: x = x.permute(0, 2, 3, 1).contiguous() # [B, C, H, W] -> [B, H, W, C]
326
+ for j in range(self.depths[i]):
327
+ x= self.stages[i][j](x)
328
+ outs.append(x) # [B, H, W, C]
329
+ return outs
330
+
331
+ def convformer(variant='tiny'):
332
+ if variant == 'tiny':
333
+ model = convformer_t()
334
+
335
+ elif variant == 'small':
336
+ model = convformer_s()
337
+
338
+ elif variant == 'base':
339
+ model = convformer_b()
340
+
341
+ elif variant == 'large':
342
+ model = convformer_l()
343
+
344
+ else:
345
+ raise NotImplementedError
346
+
347
+ return model
348
+
349
+ @register_model
350
+ def convformer_t(**kwargs):
351
+ model = MetaFormer(
352
+ depths=[2, 2, 6, 2],
353
+ dims=[32, 64, 128, 160],
354
+ mlps=Mlp, mlp_ratio=2,
355
+ token_mixers=[SepConv, SepConv, SepConv, SepConv],
356
+ head_fn=nn.Linear,
357
+ **kwargs)
358
+ return model
359
+
360
+ @register_model
361
+ def convformer_s(**kwargs):
362
+ model = MetaFormer(
363
+ depths=[2, 2, 6, 2],
364
+ dims=[64, 128, 160, 320],
365
+ mlps=Mlp, mlp_ratio=2,
366
+ token_mixers=[SepConv, SepConv, SepConv, SepConv],
367
+ head_fn=nn.Linear,
368
+ **kwargs)
369
+ return model
370
+
371
+ @register_model
372
+ def convformer_b(**kwargs):
373
+ model = MetaFormer(
374
+ depths=[2, 2, 6, 2],
375
+ dims=[128, 256, 320, 512],
376
+ mlps=Mlp, mlp_ratio=2,
377
+ token_mixers=[SepConv, SepConv, SepConv, SepConv],
378
+ head_fn=nn.Linear,
379
+ **kwargs)
380
+ return model
381
+
382
+ @register_model
383
+ def convformer_l(**kwargs):
384
+ model = MetaFormer(
385
+ depths=[2, 2, 6, 2],
386
+ dims=[256, 384, 512, 768],
387
+ mlps=Mlp, mlp_ratio=2,
388
+ token_mixers=[SepConv, SepConv, SepConv, SepConv],
389
+ head_fn=nn.Linear,
390
+ **kwargs)
391
+ return model
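The backbone returns channels-last features at strides 4, 8, 16 and 32; a small sketch that instantiates the tiny variant and prints the pyramid shapes (the input size is arbitrary but should be divisible by 32):

```python
import torch
from models.convformer import convformer

backbone = convformer('tiny')            # dims = [32, 64, 128, 160]
x = torch.randn(1, 3, 256, 320)
with torch.no_grad():
    feats = backbone(x)
for f in feats:
    print(tuple(f.shape))                # (1, 64, 80, 32) (1, 32, 40, 64) (1, 16, 20, 128) (1, 8, 10, 160)
```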
models/cost_volume.py ADDED
@@ -0,0 +1,179 @@
1
+ from __future__ import print_function
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from models.convformer import LayerNormWithoutBias
6
+ from utils.utils import init_coords
7
+
8
+ class GlobalCorrelation(nn.Module):
9
+
10
+ def __init__(self, dim):
11
+ super().__init__()
12
+ self.norm = LayerNormWithoutBias(dim)
13
+ self.q = nn.Linear(dim, dim, bias=False)
14
+ self.k = nn.Linear(dim, dim, bias=False)
15
+ self.scale = dim**-0.5
16
+
17
+ def forward(self, x, stereo=True):
18
+ x = self.norm(x)
19
+ ref, tgt = x.chunk(2, dim=0)
20
+ ref, tgt = self.q(ref), self.k(tgt)
21
+ # global correlation on horizontal direction
22
+ B, H, W, C = ref.shape
23
+
24
+ if stereo:
25
+ correlation = torch.matmul(ref, tgt.transpose(-2, -1))*self.scale # [B, H, W, W]
26
+
27
+ # mask subsequent positions to make disparity positive
28
+ mask = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=1) # [W, W]
29
+ valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1) # [B, H, W, W]
30
+
31
+ mask_ = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=0) # mask for input order [right, left]
32
+ valid_mask_ = (mask_ != 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1) # upper right
33
+ valid_mask = torch.cat((valid_mask, valid_mask_), dim=0) # [B*2, H, W, W]
34
+ correlation = torch.cat((correlation, correlation.permute(0, 1, 3, 2)), dim=0) # [B*2, H, W, W]
35
+ B = B*2
36
+
37
+ correlation[~valid_mask] = -1e9 if correlation.dtype == torch.float32 else -1e4
38
+
39
+ # build volume from correlation
40
+ D = W # all-pair correlation
41
+ volume = correlation.new_zeros([B, D, H, W])
42
+ for d in range(D): # most time-consuming
43
+ volume[:B//2, d, :, d:] = correlation[:B//2, :, range(d, W), range(W-d)]
44
+ volume[B//2:, d, :, :(W-d)] = correlation[B//2:, :, range(W-d), range(d, W)]
45
+
46
+ volume = F.softmax(volume, dim=1).to(volume.dtype)
47
+
48
+ volume_clone = volume.clone()
49
+ for d in range(D): # fill out of view # second time-consuming
50
+ volume_clone[:B//2, d, :, :d] = volume[:B//2, d, :, d:d+1] # left
51
+ volume_clone[B//2:, d, :, W-1-d:] = volume[B//2:, d, :, W-1-d:(W-d)] # right
52
+
53
+ flow = local_disparity_estimator(volume_clone)
54
+ return flow, volume_clone
55
+ else:
56
+ init_grid = init_coords(ref) # [B, H, W, 2]
57
+ ref = ref.view(B, -1, C) # [B, H*W, C]
58
+ tgt = tgt.view(B, -1, C) # [B, H*W, C]
59
+
60
+ correlation = torch.matmul(ref, tgt.transpose(-2, -1))*self.scale # [B, H*W, H*W]
61
+ correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
62
+ init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, H, W, 2]
63
+ B = B * 2
64
+
65
+ prob = F.softmax(correlation, dim=-1).to(correlation.dtype) # [B, H*W, H*W]
66
+
67
+ flow = local_flow_estimator(prob, init_grid)
68
+
69
+ return flow, prob.view(B, H, W, H*W)
70
+
71
+ def local_flow_estimator(prob, init_grid, k=5):
72
+ """
73
+ Flow estimator using weighted sum within local window centered at max prob
74
+ Args:
75
+ prob: normalized correlation volume [B, H*W, H*W]
76
+ init_grid: init coordinate grid [B, H, W, 2]
77
+ k: local window size (odd number)
78
+ Returns:
79
+ flow: optical field [B, H, W, 2]
80
+ """
81
+ B, H, W, _ = init_grid.shape
82
+ r = k // 2
83
+ device = prob.device
84
+
85
+ prob_blur = F.avg_pool2d(prob, kernel_size=k, stride=1, padding=r).view(B, H*W, H*W)
86
+
87
+ max_prob, max_idx = torch.max(prob_blur, dim=-1) # [B, H*W]
88
+ max_idx = max_idx.unsqueeze(-1) # [B, H*W, 1]
89
+ target_coords = init_grid # [B, H, W, 2]
90
+ max_y = max_idx // W # [B, H*W, 1]
91
+ max_x = max_idx % W # [B, H*W, 1]
92
+ max_y = torch.clamp(max_y, r, H-1-r)
93
+ max_x = torch.clamp(max_x, r, W-1-r)
94
+
95
+ yy, xx = torch.meshgrid(torch.arange(-r, r+1, device=device), torch.arange(-r, r+1, device=device), indexing='ij')
96
+ offsets_y = yy.reshape(1, 1, k*k, 1) # [1, 1, k*k, 1]
97
+ offsets_x = xx.reshape(1, 1, k*k, 1) # [1, 1, k*k, 1]
98
+ sample_y = max_y.unsqueeze(2) + offsets_y # [B, H*W, k*k, 1]
99
+ sample_x = max_x.unsqueeze(2) + offsets_x # [B, H*W, k*k, 1]
100
+ sample_y = sample_y.long().squeeze(-1) # [B, H*W, k*k]
101
+ sample_x = sample_x.long().squeeze(-1) # [B, H*W, k*k]
102
+
103
+ batch_idx = torch.arange(B, device=device).view(B, 1, 1).expand(-1, H*W, k*k)
104
+ window_coords = target_coords[batch_idx, sample_y, sample_x] # [B, H*W, k*k, 2]
105
+
106
+ window_indices = sample_y * W + sample_x # [B, H*W, k*k]
107
+ window_probs = torch.gather(prob, dim=-1, index=window_indices) # [B, H*W, k*k]
108
+
109
+ mean_prob = 1.0 / (H * W)
110
+ invalid_mask = window_probs < mean_prob
111
+ window_probs[invalid_mask] = 0
112
+
113
+ window_probs_sum = window_probs.sum(dim=-1, keepdim=True).to(window_probs.dtype)
114
+ window_probs_sum = torch.clamp(window_probs_sum, min=torch.finfo(window_probs_sum.dtype).tiny)
115
+ normalized_probs = window_probs / window_probs_sum # [B, H*W, k*k]
116
+ normalized_probs = normalized_probs.unsqueeze(-1) # [B, H*W, k*k, 1]
117
+ correspondence = torch.sum(normalized_probs * window_coords, dim=2).to(normalized_probs.dtype) # [B, H*W, 2]
118
+ correspondence = correspondence.view(B, H, W, 2) # [B, H, W, 2]
119
+ flow = correspondence - init_grid
120
+
121
+ return flow
122
+
123
+ def local_disparity_estimator(cv, k=5):
124
+ """
125
+ Disparity estimator using weighted sum within local window centered at max prob
126
+ Args:
127
+ cv: cost volume [B, D, H, W]
128
+ k: local window size (odd number)
129
+ Returns:
130
+ flow: [B, H, W, 2]
131
+ """
132
+ B, D, H, W = cv.shape
133
+ r = k // 2
134
+ device = cv.device
135
+
136
+ cv_blur = F.avg_pool1d(cv.permute(0, 2, 3, 1).view(B, -1, D), kernel_size=k, stride=1, padding=r).view(B, H, W, D).permute(0, 3, 1, 2)
137
+
138
+ # find max idx in blurred cv
139
+ max_cv, max_idx = torch.max(cv_blur, dim=1) # max_idx: [B, H, W]
140
+ max_idx = max_idx.unsqueeze(1) # [B, 1, H, W]
141
+ max_idx = torch.clamp(max_idx, r, D-1-r) # [B, 1, H, W]
142
+
143
+ offsets = torch.arange(-r, r+1, device=device).view(1, k, 1, 1) # [1, k, 1, 1]
144
+
145
+ sample_idx = max_idx + offsets # [B, k, H, W]
146
+ sample_idx = torch.clamp(sample_idx, 0, D-1)
147
+
148
+ batch_idx = torch.arange(B, device=device).view(B, 1, 1, 1).expand(-1, k, H, W)
149
+ h_idx = torch.arange(H, device=device).view(1, 1, H, 1).expand(B, k, H, W)
150
+ w_idx = torch.arange(W, device=device).view(1, 1, 1, W).expand(B, k, H, W)
151
+
152
+ window_probs = cv[batch_idx, sample_idx, h_idx, w_idx] # [B, k, H, W]
153
+
154
+ mean_prob = 1.0 / D
155
+ invalid_mask = window_probs < mean_prob
156
+ window_probs[invalid_mask] = 0
157
+
158
+ # normalize within local window
159
+ window_probs_sum = window_probs.sum(dim=1, keepdim=True).to(window_probs.dtype) # [B, 1, H, W]
160
+ window_probs_sum = torch.clamp(window_probs_sum, min=torch.finfo(window_probs_sum.dtype).tiny)
161
+ normalized_probs = window_probs / window_probs_sum # [B, k, H, W]
162
+
163
+ window_disp = sample_idx.to(normalized_probs.dtype) # [B, k, H, W]
164
+
165
+ disp = torch.sum(normalized_probs * window_disp, dim=1).to(normalized_probs.dtype).unsqueeze(-1) # [B, H, W, 1]
166
+
167
+ return disp_to_flow(disp, B)
168
+
169
+ def disp_to_flow(disp, B):
170
+ ## disp[:B//2, ...] = -disp[:B//2, ...] # negative left flow
171
+
172
+ ## for onnx support
173
+ batch_indices = torch.arange(B, device=disp.device)
174
+ mask = batch_indices < (B // 2)
175
+
176
+ disp = torch.where(mask.view(B, 1, 1, 1), -disp, disp)
177
+
178
+ flow = torch.cat((disp, torch.zeros_like(disp)), dim=-1).contiguous() # [B, H, W, 2]
179
+ return flow
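`GlobalCorrelation` masks matches that would yield a negative disparity before the softmax. The sketch below shows that masking idea on one row of made-up correlation scores, using a plain soft-argmax expectation as a simpler stand-in for the windowed estimator implemented above:

```python
import torch
import torch.nn.functional as F

W = 6
corr = torch.randn(W, W)                          # corr[x_left, x_right] for a single image row
mask = torch.triu(torch.ones(W, W), diagonal=1)   # x_right > x_left would mean negative disparity
corr = corr.masked_fill(mask.bool(), float('-inf'))

prob = F.softmax(corr, dim=-1)                    # match distribution over valid right-image columns
x_left = torch.arange(W, dtype=torch.float32).unsqueeze(1)
x_right = torch.arange(W, dtype=torch.float32).unsqueeze(0)
disparity = (prob * (x_left - x_right)).sum(dim=-1)
print(disparity)                                  # non-negative soft-argmax disparity per left column
```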
models/mat_pytorch_impl.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+
3
+ def compute_bilinear_weights(grid):
4
+ """
5
+ Compute bilinear weights for BilinearSoftmax
6
+ Args:
7
+ grid: [..., 2], (x, y)
8
+ Returns:
9
+ weights: [..., 4], [nw, ne, sw, se]
10
+ """
11
+ x = grid[..., 0]
12
+ y = grid[..., 1]
13
+
14
+ x0 = torch.floor(x)
15
+ y0 = torch.floor(y)
16
+
17
+ dx = x - x0
18
+ dy = y - y0
19
+
20
+ nw = (1 - dx) * (1 - dy)
21
+ ne = dx * (1 - dy)
22
+ sw = (1 - dx) * dy
23
+ se = dx * dy
24
+
25
+ weights = torch.stack([nw, ne, sw, se], dim=-1)
26
+
27
+ return weights
28
+
29
+ def compute_match_attention(q, k, m_id, win_r, H, W):
30
+ """
31
+ Args:
32
+ q: [B, N, h, C] # Query tensor
33
+ k: [B, N, h, C] # Key tensor
34
+ m_id: [B, N, h, 2] # Sampling centers, last dim is (x, y)
35
+ win_r: List[int] # Sampling window radii along x and y
36
+ H: int # Height
37
+ W: int # Width
38
+
39
+ Returns:
40
+ output: [B, N, h, M] where M = (2*win_r[0]+2)*(2*win_r[1]+2)
41
+ """
42
+ B, N, h, C = q.shape
43
+ M = (2*win_r[0] + 2)*(2*win_r[1] + 2)
44
+
45
+ dx = torch.arange(-win_r[0], win_r[0] + 2, device=q.device, dtype=torch.long)
46
+ dy = torch.arange(-win_r[1], win_r[1] + 2, device=q.device, dtype=torch.long)
47
+ dy, dx = torch.meshgrid(dy, dx, indexing='ij')
48
+ offsets = torch.stack((dx, dy), dim=-1).reshape(M, 2) # [M, 2]
49
+
50
+ centers = m_id.unsqueeze(3) # [B, N, h, 1, 2]
51
+ offsets = offsets.view(1, 1, 1, M, 2) # [1, 1, 1, M, 2]
52
+ coords = centers + offsets # [B, N, h, M, 2]
53
+
54
+ x_coords = coords[..., 0] # [B, N, h, M]
55
+ y_coords = coords[..., 1] # [B, N, h, M]
56
+
57
+ # Clamp coordinates to valid range
58
+ x_coords = x_coords.clamp(0, W-1)
59
+ y_coords = y_coords.clamp(0, H-1)
60
+
61
+ indices = y_coords * W + x_coords # [B, N, h, M]
62
+
63
+ # [B, N, h, C] -> [B, N, h, M, C]
64
+ k_expanded = k.unsqueeze(3).expand(-1, -1, -1, M, -1)
65
+
66
+ # [B, N, h, M] -> [B, N, h, M, C]
67
+ indices_gather = indices.unsqueeze(-1).expand(-1, -1, -1, -1, C)
68
+
69
+ # [B, N, h, M, C]
70
+ k_sampled = torch.gather(k_expanded, dim=1, index=indices_gather)
71
+
72
+ # [B, N, h, M, C] -> [B, N, h, M]
73
+ # negative L1 norm
74
+ output = -torch.abs(q.unsqueeze(3) - k_sampled).sum(dim=-1)
75
+
76
+ return output, indices_gather
77
+
78
+ def attn_scatter(attn, win_r):
79
+ """
80
+ Scatter the attn to four sub-windows
81
+
82
+ Args:
83
+ attn: [B, N, h, M], M = (2*win_r[0]+2) * (2*win_r[1]+2)
84
+ win_r: window radius
85
+
86
+ Returns:
87
+ attn_sub: [B, N, h, 4, M_sub] attn for four sub-windows
88
+ """
89
+ B, N, h, M = attn.shape
90
+ M_sub = (2*win_r[0] + 1)*(2*win_r[1] + 1)
91
+
92
+ # [B, N, h, H_win, W_win]
93
+ attn_2d = attn.view(B, N, h, 2*win_r[0] + 2, 2*win_r[1] + 2)
94
+
95
+ # nw [0, 0] offset
96
+ win_nw = attn_2d[..., :2*win_r[0]+1, :2*win_r[1]+1]
97
+ # ne [1, 0] offset
98
+ win_ne = attn_2d[..., :2*win_r[0]+1, 1:2*win_r[1]+2]
99
+ # sw [0, 1] offset
100
+ win_sw = attn_2d[..., 1:2*win_r[0]+2, :2*win_r[1]+1]
101
+ # se [1, 1] offset
102
+ win_se = attn_2d[..., 1:2*win_r[0]+2, 1:2*win_r[1]+2]
103
+
104
+ win_nw = win_nw.reshape(B, N, h, M_sub)
105
+ win_ne = win_ne.reshape(B, N, h, M_sub)
106
+ win_sw = win_sw.reshape(B, N, h, M_sub)
107
+ win_se = win_se.reshape(B, N, h, M_sub)
108
+
109
+ attn_sub = torch.stack([win_nw, win_ne, win_sw, win_se], dim=3)
110
+
111
+ return attn_sub
112
+
113
+ def attn_gather(attn_sub, win_r):
114
+ """
115
+ Gather the four attn_sub to attn
116
+
117
+ Args:
118
+ attn_sub: [B, N, h, 4, M_sub]
119
+ win_r: window radius
120
+
121
+ Returns:
122
+ merged_attn: [B, N, h, M]
123
+ """
124
+ B, N, h, _, M_sub = attn_sub.shape
125
+
126
+ merged = torch.zeros(B, N, h, 2*win_r[0] + 2, 2*win_r[1] + 2, device=attn_sub.device, dtype=attn_sub.dtype)
127
+
128
+ # nw [0, 0] offset
129
+ win_nw = attn_sub[:, :, :, 0, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
130
+ merged[..., :2*win_r[0]+1, :2*win_r[1]+1] += win_nw
131
+
132
+ # ne [1, 0] offset
133
+ win_ne = attn_sub[:, :, :, 1, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
134
+ merged[..., :2*win_r[0]+1, 1:2*win_r[1]+2] += win_ne
135
+
136
+ # sw [0, 1] offset
137
+ win_sw = attn_sub[:, :, :, 2, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
138
+ merged[..., 1:2*win_r[0]+2, :2*win_r[1]+1] += win_sw
139
+
140
+ # se [1, 1] offset
141
+ win_se = attn_sub[:, :, :, 3, :].view(B, N, h, 2*win_r[0]+1, 2*win_r[1]+1)
142
+ merged[..., 1:2*win_r[0]+2, 1:2*win_r[1]+2] += win_se
143
+
144
+ merged_attn = merged.view(B, N, h, -1)
145
+
146
+ return merged_attn
147
+
148
+ def compute_bilinear_softmax(attn, bilinear_weight, win_r):
149
+ """
150
+ Bilinear Softmax: attention sampled at a continuous (sub-pixel) position
151
+
152
+ Args:
153
+ attn: [B, N, h, M] attention logits at discrete window positions
+ bilinear_weight: [B, N, h, 4] bilinear weights [nw, ne, sw, se]
154
+ win_r: window radius
155
+
156
+ Returns:
157
+ output: [B, N, h, M] effective attention at the continuous position
158
+ """
159
+ attn_sub = attn_scatter(attn, win_r) # [B, N, h, 4, M_sub]
160
+
161
+ attn_weighted = bilinear_weight.unsqueeze(-1)*attn_sub.softmax(dim=-1)
162
+
163
+ output = attn_gather(attn_weighted, win_r) # [B, N, h, M]
164
+
165
+ return output
166
+
167
+ def attention_aggregate(v, attn, indices_gather, win_r):
168
+ """ Aggregate v [B, N, h, C] over the sampled window with attention weights attn [B, N, h, M] """
169
+ B, N, h, C = v.shape
170
+ M = (2*win_r[0] + 2)*(2*win_r[1] + 2)
171
+
172
+ # [B, N, h, C] -> [B, N, h, M, C]
173
+ v_expanded = v.unsqueeze(3).expand(-1, -1, -1, M, -1)
174
+ v_sampled = torch.gather(v_expanded, dim=1, index=indices_gather)
175
+
176
+ output = (attn.unsqueeze(-1)*v_sampled).sum(dim=3)
177
+
178
+ return output.view(B, N, -1)
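The reference ops above mirror the fused CUDA kernel released further down. Below is a minimal sketch of how they might compose into a single MatchAttention step; the wrapper name and call order are illustrative assumptions, not the repo's actual module (which calls the fused op instead):

```python
import torch

def match_attention_reference(q, k, v, m_offset, win_r, H, W, scale=1.0):
    # q, k, v: [B, N, h, C]; m_offset: [B, N, h, 2] continuous (x, y) sampling centers; N == H * W
    bilinear_weight = compute_bilinear_weights(m_offset)        # [B, N, h, 4], order [nw, ne, sw, se]
    m_id = torch.floor(m_offset).long()                         # integer window centers for gathering
    attn, indices_gather = compute_match_attention(q, k, m_id, win_r, H, W)  # negative L1 logits
    attn = compute_bilinear_softmax(attn * scale, bilinear_weight, win_r)    # softmax per sub-window
    return attention_aggregate(v, attn, indices_gather, win_r)  # [B, N, h * C]
```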
models/match_former_ops.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Tuple
4
+
5
+
6
+ @torch.library.custom_op("match_attention::fused_forward_ops", mutates_args={"output", "attn_out"})
7
+ def fused_forward_ops(
8
+ max_offset: torch.Tensor,
9
+ q: torch.Tensor,
10
+ k: torch.Tensor,
11
+ v: torch.Tensor,
12
+ output: torch.Tensor,
13
+ attn_out: torch.Tensor,
14
+ H: int,
15
+ W: int,
16
+ win_r: List[int],
17
+ attn_num: int,
18
+ attn_type: str,
19
+ scale: float
20
+ ) -> None:
21
+ """
22
+ Opaque custom op for fused forward pass that prevents torch.compile tracing.
23
+
24
+ This wrapper ensures that torch.compile treats this as an opaque operation
25
+ and doesn't try to trace into the CUDA kernel internals.
26
+ """
27
+ # Call the original CUDA extension
28
+ try:
29
+ import match_attention
30
+ match_attention.fused_forward(
31
+ max_offset, q, k, v, output, attn_out,
32
+ H, W, win_r, attn_num, attn_type, scale
33
+ )
34
+ except ImportError:
35
+ # Fallback to torch.ops if direct import fails
36
+ torch.ops.match_attention.fused_forward(
37
+ max_offset, q, k, v, output, attn_out,
38
+ H, W, win_r, attn_num, attn_type, scale
39
+ )
40
+
41
+
42
+ @fused_forward_ops.register_fake
43
+ def _(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, scale):
44
+ """
45
+ Fake implementation for torch.compile that defines tensor shapes and dtypes
46
+ without actually executing the kernel.
47
+ """
48
+ # Validate input shapes
49
+ B, N, C = q.shape
50
+ h = max_offset.size(2)
51
+
52
+ # Ensure output tensors have correct shapes
53
+ torch._check(output.shape == (B, N, C), lambda: f"output shape mismatch: expected {(B, N, C)}, got {output.shape}")
54
+ torch._check(attn_out.shape == (B, N, h, attn_num), lambda: f"attn_out shape mismatch: expected {(B, N, h, attn_num)}, got {attn_out.shape}")
55
+
56
+ # Ensure output tensors have correct dtypes and devices
57
+ torch._check(output.dtype == q.dtype, lambda: f"output dtype mismatch: expected {q.dtype}, got {output.dtype}")
58
+ torch._check(attn_out.dtype == q.dtype, lambda: f"attn_out dtype mismatch: expected {q.dtype}, got {attn_out.dtype}")
59
+ torch._check(output.device == q.device, lambda: f"output device mismatch: expected {q.device}, got {output.device}")
60
+ torch._check(attn_out.device == q.device, lambda: f"attn_out device mismatch: expected {q.device}, got {attn_out.device}")
61
+
62
+ return None
63
+
64
+
65
+ class MF_FusedForwardOps(nn.Module):
66
+ """
67
+ Opaque MatchAttention fused forward, optimized for torch.compile
68
+
69
+ This version uses torch.library.custom_op to create opaque custom operators,
70
+ preventing torch.compile from tracing into CUDA kernel internals.
71
+ """
72
+
73
+ def __init__(self):
74
+ super().__init__()
75
+
76
+ def forward(
77
+ self,
78
+ max_offset: torch.Tensor,
79
+ q: torch.Tensor,
80
+ k: torch.Tensor,
81
+ v: torch.Tensor,
82
+ H: int,
83
+ W: int,
84
+ win_r: List[int],
85
+ attn_num: int,
86
+ attn_type: str = 'l1_norm',
87
+ scale: float = 1.0
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """
90
+ Fused forward
91
+
92
+ Args:
93
+ max_offset: Offset tensor with shape [B, N, h, 2]
94
+ q: Query tensor with shape [B, N, C]
95
+ k: Key tensor with shape [B, N, C]
96
+ v: Value tensor with shape [B, N, C]
97
+ H: Feature map height
98
+ W: Feature map width
99
+ win_r: Window radius [r_h, r_w]
100
+ attn_num: Number of sampled window positions per head, (2*win_r[0]+2)*(2*win_r[1]+2)
101
+ attn_type: Attention type ('l1_norm' or 'l2_norm')
102
+ scale: Scale factor
103
+
104
+ Returns:
105
+ output: Output features with shape [B, N, C]
106
+ attn_out: Attention weights with shape [B, N, h, attn_num]
107
+ """
108
+ B, N, C = q.shape
109
+ h = max_offset.size(2)
110
+
111
+ # Create output tensors
112
+ output = torch.zeros_like(v)
113
+ attn_out = q.new_zeros([B, N, h, attn_num])
114
+
115
+ # Call opaque custom operator
116
+ fused_forward_ops(
117
+ max_offset, q, k, v, output, attn_out,
118
+ H, W, win_r, attn_num, attn_type, scale
119
+ )
120
+
121
+ return output, attn_out
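A rough usage sketch for the opaque wrapper. It requires the compiled `match_attention` CUDA extension and tensors laid out with N == H * W; all sizes below are illustrative assumptions:

```python
import torch

B, H, W, C, h = 2, 64, 128, 128, 4
N = H * W
win_r = [1, 1]
attn_num = (2 * win_r[0] + 2) * (2 * win_r[1] + 2)    # 16 sampled positions per head

fused = MF_FusedForwardOps()
q = torch.randn(B, N, C, device='cuda')
k = torch.randn(B, N, C, device='cuda')
v = torch.randn(B, N, C, device='cuda')
max_offset = torch.zeros(B, N, h, 2, device='cuda')   # continuous (x, y) sampling centers

output, attn_out = fused(max_offset, q, k, v, H, W, win_r, attn_num, attn_type='l1_norm', scale=1.0)
# output: [B, N, C], attn_out: [B, N, h, attn_num]
```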
models/match_stereo.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from timm.models.layers import trunc_normal_
6
+ from models.common import UpConv
7
+ from models.convformer import convformer
8
+ from models.attention_blocks import MatchAttentionBlock
9
+ from models.cost_volume import GlobalCorrelation
10
+
11
+ class MatchStereo(nn.Module):
12
+ def __init__(self, args,
13
+ refine_win_rs=[2, 2, 1, 1], # refine window radius at 1/32, 1/16, 1/8, 1/4
14
+ refine_nums=[8, 8, 8, 2],
15
+ num_heads=[4, 4, 4, 4],
16
+ mlp_ratios=[2, 2, 2, 2],
17
+ drop_path=0.):
18
+ super().__init__()
19
+ self.refine_nums = refine_nums
20
+
21
+ self.encoder = convformer(args.variant)
22
+ self.channels = self.encoder.dims[::-1] # resolution low to high
23
+ self.num_heads = num_heads
24
+ self.head_dims = [c//h for c, h in zip(self.channels, self.num_heads)]
25
+
26
+ self.factor = 2
27
+ self.factor_last = 2**(len(self.channels) - len(refine_nums) + 2)
28
+
29
+ self.field_dim = 2 # 2(flow)
30
+
31
+ self.up_decoders = nn.ModuleList()
32
+ self.up_masks = nn.ModuleList()
33
+ for i in range(len(self.channels)):
34
+ if i > 0:
35
+ self.up_decoders.append(UpConv(self.channels[i-1], self.channels[i]))
36
+ self.up_masks.append(
37
+ nn.Sequential(
38
+ nn.Conv2d(self.channels[i-1], self.channels[i-1], 3, padding=1),
39
+ nn.ReLU(inplace=True),
40
+ nn.Conv2d(self.channels[i-1], (self.factor**2)*9, 1, padding=0))
41
+ )
42
+ else:
43
+ self.up_decoders.append(nn.Identity())
44
+ self.up_masks.append(nn.Identity())
45
+
46
+ self.up_masks.append(
47
+ nn.Sequential(
48
+ nn.Conv2d(self.channels[-1], self.channels[-1]*2, 3, padding=1),
49
+ nn.ReLU(inplace=True),
50
+ nn.Conv2d(self.channels[-1]*2, (self.factor_last**2)*9, 1, padding=0)))
51
+
52
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path, sum(refine_nums))]
53
+ # MatchAttention
54
+ self.match_attentions = nn.ModuleList()
55
+ for i in range(len(refine_nums)):
56
+ self.match_attentions.append(
57
+ MatchAttentionBlock(args, self.channels[i], win_r=refine_win_rs[i],
58
+ num_layer=refine_nums[i], num_head=self.num_heads[i], head_dim=self.head_dims[i],
59
+ mlp_ratio=mlp_ratios[i], field_dim=self.field_dim,
60
+ dp_rates=dp_rates[sum(refine_nums[:i]):sum(refine_nums[:i+1])])
61
+ )
62
+
63
+ self.init_correlation_volume = GlobalCorrelation(self.channels[0])
64
+
65
+ self.apply(self._init_weights)
66
+
67
+ def _init_weights(self, m):
68
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
69
+ trunc_normal_(m.weight, std=.02)
70
+ if m.bias is not None:
71
+ nn.init.constant_(m.bias, 0)
72
+
73
+ def upsample_field(self, field, mask, factor):
74
+ ''' Upsample field [H, W, D] -> [factor*H, factor*W, D] using convex combination of 3x3 neighbors '''
75
+ B, H, W, D = field.shape
76
+ field = field.permute(0, 3, 1, 2)
77
+ mask = mask.view(B, 1, 9, factor, factor, H, W)
78
+ mask = torch.softmax(mask, dim=2).to(mask.dtype)
79
+ up_flow = F.unfold(field*factor, [3,3], padding=1)
80
+ up_flow = up_flow.view(B, D, 9, 1, 1, H, W)
81
+
82
+ up_flow = torch.sum(mask * up_flow, dim=2).to(mask.dtype) # [B, D, factor, factor, H, W]
83
+ up_flow = up_flow.permute(0, 4, 2, 5, 3, 1)
84
+ return up_flow.reshape(B, factor*H, factor*W, D).contiguous()
85
+
86
+ def forward(self, img0, img1, stereo=True, init_flow=None):
87
+ ''' Estimate optical flow/disparity between a pair of frames; outputs bi-directional flow/disparity '''
88
+ field_all = []
89
+
90
+ img0 = (2 * (img0 / 255.0) - 1.0).contiguous()
91
+ img1 = (2 * (img1 / 255.0) - 1.0).contiguous()
92
+
93
+ x = torch.cat((img0, img1), dim=0) # cat in batch dim
94
+
95
+ features = self.encoder(x) # [B*2, H, W, C]
96
+ features = features[::-1] # reverse 1/32, 1/16, 1/8, 1/4
97
+
98
+ for i in range(len(features)): # 1/32, 1/16, 1/8, 1/4
99
+ if i==0:
100
+ if init_flow is None:
101
+ init_flow, init_cv = self.init_correlation_volume(features[i], stereo=stereo)
102
+ else:
103
+ init_cv = None
104
+
105
+ field = init_flow.clone() # [B, H, W, 2]
106
+ self_rpos = torch.zeros_like(field)
107
+ else:
108
+ features[i] = self.up_decoders[i](features[i-1], features[i])
109
+ up_mask = self.up_masks[i](features[i-1].permute(0, 3, 1, 2)) # [B, C, H, W]
110
+ self_rpos = self.upsample_field(self_rpos, up_mask, self.factor)
111
+ field = self.upsample_field(field, up_mask, self.factor)
112
+ field_all.append({'self':field})
113
+
114
+ features[i], self_rpos, field, fields = self.match_attentions[i](features[i], self_rpos, field, stereo=stereo)
115
+ field_all.extend(fields)
116
+
117
+ if self.training:
118
+ B = field.shape[0]
119
+ field_up = self.upsample_field(field[:B//2], self.up_masks[-1](features[-1][:B//2].permute(0, 3, 1, 2)), self.factor_last)
120
+ field_up = torch.cat((field_up, field_up), dim=0) # dummy output
121
+ else:
122
+ field_up = self.upsample_field(field, self.up_masks[-1](features[-1].permute(0, 3, 1, 2)), self.factor_last)
123
+
124
+ return {
125
+ 'init_flow': init_flow,
126
+ 'init_cv': init_cv,
127
+ 'field_all': field_all,
128
+ 'field_up': field_up,
129
+ 'self_rpos': self_rpos,
130
+ }
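A heavily hedged inference sketch for MatchStereo. The exact fields required on `args` (and the valid encoder `variant` names) are defined in `models/convformer.py` and the attention blocks, so the `SimpleNamespace` below is only a placeholder:

```python
import torch
from types import SimpleNamespace

args = SimpleNamespace(variant='tiny')             # assumed field name and value
model = MatchStereo(args).cuda().eval()

# Inputs are 0-255 images; H and W should be divisible by 32 (the encoder's coarsest stride).
left = torch.rand(1, 3, 384, 768, device='cuda') * 255
right = torch.rand(1, 3, 384, 768, device='cuda') * 255

with torch.no_grad():
    pred = model(left, right, stereo=True)

field_up = pred['field_up']              # [2*B, H, W, 2] bi-directional field at full resolution
disp_left = field_up[:1, ..., 0].abs()   # x-component of the left view's field (sign convention assumed)
```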
models/setup.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
4
+
5
+ setup(
6
+ name='match_attention',
7
+ version='0.7',
8
+ description='Match Attention CUDA Extension for PyTorch',
9
+ author='TingmanYan',
10
+ ext_modules=[
11
+ CUDAExtension('match_attention', [
12
+ 'src/match_former_cuda.cpp',
13
+ 'src/match_former_cuda_kernel.cu',
14
+ 'src/match_former_fused_forward.cu',
15
+ ]),
16
+ ],
17
+ cmdclass={
18
+ 'build_ext': BuildExtension
19
+ }
20
+ )
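The extension would typically be built in place before the Python wrappers above can import it, e.g. `cd models && python setup.py build_ext --inplace` (or `python setup.py install`); this assumes a local CUDA toolkit compatible with the installed PyTorch build.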
models/src/match_former_cuda.cpp ADDED
@@ -0,0 +1,49 @@
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+ #include <pybind11/pybind11.h>
4
+ #include <pybind11/stl.h>
5
+ #include <string>
6
+ #include <ATen/core/op_registration/op_registration.h>
7
+
8
+ // CUDA declarations
9
+
10
+ void mf_fused_forward_cuda(
11
+ at::Tensor max_offset,
12
+ at::Tensor q,
13
+ at::Tensor k,
14
+ at::Tensor v,
15
+ at::Tensor output,
16
+ at::Tensor attn_out,
17
+ const int H,
18
+ const int W,
19
+ const std::vector<int64_t>& win_r,
20
+ const int attn_num,
21
+ const std::string& attn_type,
22
+ const float scale);
23
+
24
+ void mf_fused_forward(
25
+ at::Tensor max_offset,
26
+ at::Tensor q,
27
+ at::Tensor k,
28
+ at::Tensor v,
29
+ at::Tensor output,
30
+ at::Tensor attn_out,
31
+ const int64_t H,
32
+ const int64_t W,
33
+ const std::vector<int64_t>& win_r,
34
+ const int64_t attn_num,
35
+ const std::string& attn_type,
36
+ const double scale)
37
+ {
38
+ mf_fused_forward_cuda(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, static_cast<float>(scale));
39
+ }
40
+
41
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
42
+ {
43
+ m.def("fused_forward", &mf_fused_forward, "Fused forward pass (CUDA)");
44
+ }
45
+
46
+ TORCH_LIBRARY(match_attention, m)
47
+ {
48
+ m.def("fused_forward(Tensor max_offset, Tensor q, Tensor k, Tensor v, Tensor(a!) output, Tensor(b!) attn_out, int H, int W, int[] win_r, int attn_num, str attn_type, float scale) -> ()", &mf_fused_forward);
49
+ }
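Note that the op is deliberately exposed twice: the `PYBIND11_MODULE` block serves the direct `import match_attention` path used in `models/match_former_ops.py`, while the `TORCH_LIBRARY` registration makes the same kernel reachable as `torch.ops.match_attention.fused_forward`, the fallback path when the direct import fails.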
models/src/match_former_cuda_kernel.cu ADDED
@@ -0,0 +1,26 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <vector>
7
+
8
+ #include "match_former_fused_forward.hpp"
9
+
10
+ // Fused forward function that combines all operations
11
+ void mf_fused_forward_cuda(
12
+ at::Tensor max_offset,
13
+ at::Tensor q,
14
+ at::Tensor k,
15
+ at::Tensor v,
16
+ at::Tensor output,
17
+ at::Tensor attn_out,
18
+ const int H,
19
+ const int W,
20
+ const std::vector<int64_t>& win_r,
21
+ const int attn_num,
22
+ const std::string& attn_type,
23
+ const float scale)
24
+ {
25
+ match_former_fused_forward(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, scale);
26
+ }
models/src/match_former_fused_forward.cu ADDED
@@ -0,0 +1,628 @@
1
+ #include <torch/extension.h>
2
+ #include <cuda.h>
3
+ #include <cuda_runtime.h>
4
+ #include <vector>
5
+ #include <cassert>
6
+ #include <cfloat>
7
+ #include <cuda_fp16.h>
8
+ #include <cuda_bf16.h>
9
+ #include <ATen/native/cuda/KernelUtils.cuh>
10
+
11
+ // Forward declarations of kernel functions
12
+ template <typename scalar_t>
13
+ __global__ void clip_offset_to_id_k(const scalar_t *const m_offset_d, int *const m_id_d, const int Lh, const int num_heads, const int N, const int H, const int W);
14
+
15
+ template <typename scalar_t>
16
+ __global__ void attn_weight_bilinear_forward_k(const scalar_t* const m_offset_d, scalar_t* const bilinear_weight_d, const int Lh);
17
+
18
+ __global__ void check_max_id_k(int *const m_id_d, const int L, const int N, const int H, const int W, const int num_heads, const int win_x, const int win_y);
19
+
20
+ template <typename scalar_t>
21
+ __global__ void match_attention_l1_norm_forward_k(
22
+ const scalar_t *__restrict__ q_d,
23
+ const scalar_t *__restrict__ k_d,
24
+ scalar_t *__restrict__ attn_d,
25
+ const int *__restrict__ m_id_d,
26
+ const int *__restrict__ offset_d,
27
+ const int L, const int N, const int H, const int W,
28
+ const int C, const int num_heads, const int key_dim,
29
+ const int attn_num, const int attn_numel,
30
+ const bool swap_xy);
31
+
32
+ template <typename scalar_t>
33
+ __global__ void match_attention_dot_product_forward_k(const scalar_t *const q_d, const scalar_t *const k_d, scalar_t *const attn_d, const int *const m_id_d, const int* const offset_d, const int L, const int N, const int H, const int W, const int C, const int num_heads, const int key_dim, const int attn_num, const int attn_numel, const bool swap_xy);
34
+
35
+ template <typename scalar_t>
36
+ __global__ void bilinear_softmax_forward_general_k(scalar_t* const __restrict__ attn_d,
37
+ scalar_t* const __restrict__ attn_out_d,
38
+ scalar_t* const __restrict__ attn_sum_d,
39
+ const scalar_t* const __restrict__ bilinear_weight_d,
40
+ const int* const __restrict__ select_index_d,
41
+ int L, const int num_heads, const int h_attn_num,
42
+ const int attn_num, const int attn_num_sub);
43
+
44
+ template <typename scalar_t>
45
+ __global__ void attention_aggregate_forward_k(
46
+ const scalar_t *__restrict__ v_d,
47
+ scalar_t *__restrict__ out_d,
48
+ const scalar_t *__restrict__ attn_d,
49
+ const int *__restrict__ m_id_d,
50
+ const int* __restrict__ offset_d,
51
+ const int L, const int C, const int num_heads,
52
+ const int key_dim, const int attn_num,
53
+ const bool swap_xy);
54
+
55
+ template <typename scalar_t>
56
+ __global__ void scale_attention_k(scalar_t* attn_d, const scalar_t scale, const int total_size);
57
+
58
+ // Kernel implementations
59
+ template <typename scalar_t>
60
+ __global__ void
61
+ clip_offset_to_id_k(const scalar_t *const m_offset_d, int *const m_id_d, const int Lh, const int num_heads, const int N, const int H, const int W)
62
+ {
63
+ int lh = blockIdx.x * blockDim.x + threadIdx.x;
64
+ if (lh >= Lh)
65
+ return;
66
+
67
+ int l = lh / num_heads;
68
+ int batch_id = l / N;
69
+ int m_x = __float2int_rd(static_cast<float>(m_offset_d[lh*2])); // round to floor
70
+ int m_y = __float2int_rd(static_cast<float>(m_offset_d[lh*2 + 1]));
71
+ if (m_x < 0) m_x = 0;
72
+ if (m_x >= W) m_x = W - 1;
73
+ if (m_y < 0) m_y = 0;
74
+ if (m_y >= H) m_y = H - 1;
75
+ int m_pix_id = m_y * W + m_x;
76
+ int m_id = batch_id * N + m_pix_id;
77
+ m_id_d[lh] = m_id;
78
+ }
79
+
80
+ template <typename scalar_t>
81
+ __global__ void
82
+ attn_weight_bilinear_forward_k(const scalar_t* const m_offset_d, scalar_t* const bilinear_weight_d, const int Lh)
83
+ {
84
+ int lh = blockIdx.x * blockDim.x + threadIdx.x;
85
+ if (lh >= Lh)
86
+ return;
87
+
88
+ float ix = static_cast<float>(m_offset_d[lh*2]);
89
+ float iy = static_cast<float>(m_offset_d[lh*2 + 1]);
90
+ int ix_nw = __float2int_rd(ix);
91
+ int iy_nw = __float2int_rd(iy);
92
+ int ix_ne = ix_nw + 1;
93
+ int iy_ne = iy_nw;
94
+ int ix_sw = ix_nw;
95
+ int iy_sw = iy_nw + 1;
96
+ int ix_se = ix_nw + 1;
97
+ int iy_se = iy_nw + 1;
98
+
99
+ float nw = (ix_se - ix) * (iy_se - iy);
100
+ float ne = (ix - ix_sw) * (iy_sw - iy);
101
+ float sw = (ix_ne - ix) * (iy - iy_ne);
102
+ float se = (ix - ix_nw) * (iy - iy_nw);
103
+ bilinear_weight_d[lh*4] = static_cast<scalar_t>(nw);
104
+ bilinear_weight_d[lh*4 + 1] = static_cast<scalar_t>(ne);
105
+ bilinear_weight_d[lh*4 + 2] = static_cast<scalar_t>(sw);
106
+ bilinear_weight_d[lh*4 + 3] = static_cast<scalar_t>(se); // bilinear_weight of shape [B, N, h, 4]
107
+ }
108
+
109
+ // check if the search window range is out of image coordinates
110
+ __forceinline__ __device__ void
111
+ check_within_image_coordinates(int& l_id, const int& N, const int& H, const int& W, const int& win_x, const int& win_y)
112
+ {
113
+ int pix_id = l_id % N;
114
+ int batch_id = l_id / N;
115
+ int x = pix_id % W;
116
+ int y = pix_id / W;
117
+ if (x - win_x < 0)
118
+ x = win_x;
119
+ if (x + (win_x + 1) >= W)
120
+ x = W - 1 - (win_x + 1);
121
+ if (y - win_y < 0)
122
+ y = win_y;
123
+ if (y + (win_y + 1) >= H)
124
+ y = H - 1 - (win_y + 1);
125
+ pix_id = y * W + x;
126
+ l_id = batch_id * N + pix_id;
127
+ }
128
+
129
+ __global__ void
130
+ check_max_id_k(int *const m_id_d, const int L, const int N, const int H, const int W, const int num_heads, const int win_x, const int win_y)
131
+ {
132
+ int l, h;
133
+ l = blockIdx.x * blockDim.x + threadIdx.x;
134
+ h = blockIdx.y * blockDim.y + threadIdx.y;
135
+ if (l >= L || h >= num_heads)
136
+ return;
137
+
138
+ int m_id = m_id_d[l * num_heads + h];
139
+ check_within_image_coordinates(m_id, N, H, W, win_x, win_y);
140
+ m_id_d[l * num_heads + h] = m_id;
141
+ }
142
+
143
+ template <typename scalar_t>
144
+ __global__ void match_attention_l1_norm_forward_k(
145
+ const scalar_t *__restrict__ q_d,
146
+ const scalar_t *__restrict__ k_d,
147
+ scalar_t *__restrict__ attn_d,
148
+ const int *__restrict__ m_id_d,
149
+ const int *__restrict__ offset_d,
150
+ const int L, const int N, const int H, const int W,
151
+ const int C, const int num_heads, const int key_dim,
152
+ const int attn_num, const int attn_numel,
153
+ const bool swap_xy)
154
+ {
155
+ int l, k;
156
+ if (swap_xy)
157
+ {
158
+ l = blockIdx.x * blockDim.x + threadIdx.x;
159
+ k = blockIdx.y * blockDim.y + threadIdx.y;
160
+ }
161
+ else
162
+ {
163
+ k = blockIdx.x * blockDim.x + threadIdx.x;
164
+ l = blockIdx.y * blockDim.y + threadIdx.y;
165
+ }
166
+ if (l >= L || k >= num_heads*attn_num)
167
+ return;
168
+
169
+ constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
170
+ const int h = k / attn_num;
171
+ const int attn_id = k % attn_num;
172
+ const int base_id = l*num_heads + h;
173
+ const int base_attn_id = base_id*attn_num;
174
+ const int key_id = m_id_d[base_id] + offset_d[attn_id];
175
+
176
+ const int q_base = l * C;
177
+ const int k_base = key_id * C;
178
+ const int c_start = h * key_dim / vec_size;
179
+ const int c_end = c_start + key_dim / vec_size;
180
+
181
+ const float4* q_val_vec = reinterpret_cast<const float4*>(q_d + q_base);
182
+ const float4* k_val_vec = reinterpret_cast<const float4*>(k_d + k_base);
183
+
184
+ float diff_sum = 0.0f;
185
+
186
+ for (int c = c_start; c < c_end; ++c) {
187
+ float4 q_val_f4 = __ldg(&q_val_vec[c]);
188
+ float4 k_val_f4 = __ldg(&k_val_vec[c]);
189
+
190
+ if (vec_size == 4) { // float32
191
+ diff_sum += fabsf(q_val_f4.x - k_val_f4.x) +
192
+ fabsf(q_val_f4.y - k_val_f4.y) +
193
+ fabsf(q_val_f4.z - k_val_f4.z) +
194
+ fabsf(q_val_f4.w - k_val_f4.w);
195
+ } else { // bf16/fp16 (8 elements)
196
+ if (std::is_same<scalar_t, at::Half>::value) {
197
+ const half2* q_val_h2 = reinterpret_cast<const half2*>(&q_val_f4);
198
+ const half2* k_val_h2 = reinterpret_cast<const half2*>(&k_val_f4);
199
+ #pragma unroll
200
+ for (int i = 0; i < 4; ++i) {
201
+ half2 q_h2 = q_val_h2[i];
202
+ half2 k_h2 = k_val_h2[i];
203
+ half2 diff_h2 = __habs2(__hsub2(q_h2, k_h2));
204
+ diff_sum += __half2float(diff_h2.x) + __half2float(diff_h2.y);
205
+ }
206
+ } else { // bf16
207
+ const __nv_bfloat162* q_val_bf2 = reinterpret_cast<const __nv_bfloat162*>(&q_val_f4);
208
+ const __nv_bfloat162* k_val_bf2 = reinterpret_cast<const __nv_bfloat162*>(&k_val_f4);
209
+ #pragma unroll
210
+ for (int i = 0; i < 4; ++i) {
211
+ __nv_bfloat162 q_bf2 = q_val_bf2[i];
212
+ __nv_bfloat162 k_bf2 = k_val_bf2[i];
213
+ __nv_bfloat162 diff_bf2 = __habs2(__hsub2(q_bf2, k_bf2));
214
+ diff_sum += __bfloat162float(diff_bf2.x) + __bfloat162float(diff_bf2.y);
215
+ }
216
+ }
217
+ }
218
+ }
219
+ attn_d[base_attn_id + attn_id] = static_cast<scalar_t>(-diff_sum);
220
+ }
221
+
222
+ template <typename scalar_t>
223
+ __global__ void
224
+ match_attention_dot_product_forward_k(const scalar_t *const q_d, const scalar_t *const k_d, scalar_t *const attn_d, const int *const m_id_d, const int* const offset_d, const int L, const int N, const int H, const int W, const int C, const int num_heads, const int key_dim, const int attn_num, const int attn_numel, const bool swap_xy)
225
+ {
226
+ int l, k;
227
+ if (swap_xy)
228
+ {
229
+ l = blockIdx.x * blockDim.x + threadIdx.x;
230
+ k = blockIdx.y * blockDim.y + threadIdx.y;
231
+ }
232
+ else
233
+ {
234
+ k = blockIdx.x * blockDim.x + threadIdx.x;
235
+ l = blockIdx.y * blockDim.y + threadIdx.y;
236
+ }
237
+ if (l >= L || k >= num_heads*attn_num)
238
+ return;
239
+
240
+ int h = k / attn_num;
241
+ int attn_id = k % attn_num;
242
+ int base_id = l*num_heads + h;
243
+ int base_attn_id = base_id*attn_num;
244
+ int key_id = m_id_d[base_id] + offset_d[attn_id];
245
+ scalar_t diff_sum = 0;
246
+ for (int c = h * key_dim; c < (h + 1) * key_dim; ++c)
247
+ {
248
+ diff_sum += q_d[l * C + c] * k_d[key_id * C + c];
249
+ }
250
+ attn_d[base_attn_id + attn_id] = diff_sum;
251
+ }
252
+
253
+ template <typename scalar_t>
254
+ __global__ void scale_attention_k(scalar_t* attn_d, const scalar_t scale, const int total_size)
255
+ {
256
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
257
+ if (idx >= total_size)
258
+ return;
259
+ attn_d[idx] = attn_d[idx] * scale;
260
+ }
261
+
262
+ template <typename T> struct VecType { using Type = T; };
263
+ template <> struct VecType<float> { using Type = float4; };
264
+ template <> struct VecType<__half> { using Type = float2; };
265
+ template <> struct VecType<__nv_bfloat16> { using Type = float2; };
266
+
267
+ template <typename scalar_t>
268
+ __device__ __inline__ typename VecType<scalar_t>::Type load_vec(const scalar_t* addr) {
269
+ return *reinterpret_cast<const typename VecType<scalar_t>::Type*>(addr);
270
+ }
271
+
272
+ template <typename scalar_t>
273
+ __device__ __inline__ void store_vec(scalar_t* addr, typename VecType<scalar_t>::Type val) {
274
+ *reinterpret_cast<typename VecType<scalar_t>::Type*>(addr) = val;
275
+ }
276
+
277
+ template <int WIN_SIZE, typename scalar_t>
278
+ __device__ __forceinline__ void load_window(scalar_t* window, const scalar_t* src) {
279
+ constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
280
+ constexpr int VEC_COUNT = WIN_SIZE / VEC_ELEMS;
281
+ using vec_t = typename VecType<scalar_t>::Type;
282
+
283
+ #pragma unroll 4
284
+ for (int i = 0; i < VEC_COUNT; ++i) {
285
+ vec_t vec = load_vec<scalar_t>(src + i * VEC_ELEMS);
286
+ store_vec<scalar_t>(window + i * VEC_ELEMS, vec);
287
+ }
288
+ }
289
+
290
+ template <int WIN_SIZE, typename scalar_t>
291
+ __device__ __forceinline__ void store_window(scalar_t* dst, const scalar_t* window) {
292
+ constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
293
+ constexpr int VEC_COUNT = WIN_SIZE / VEC_ELEMS;
294
+ using vec_t = typename VecType<scalar_t>::Type;
295
+
296
+ #pragma unroll 4
297
+ for (int i = 0; i < VEC_COUNT; ++i) {
298
+ vec_t vec = load_vec<scalar_t>(window + i * VEC_ELEMS);
299
+ store_vec<scalar_t>(dst + i * VEC_ELEMS, vec);
300
+ }
301
+ }
302
+
303
+ template <int WIN_SIZE, int SUB_WIN_SIZE, typename scalar_t>
304
+ __global__ void
305
+ bilinear_softmax_forward_k(scalar_t* const __restrict__ attn_d,
306
+ scalar_t* const __restrict__ attn_out_d,
307
+ scalar_t* const __restrict__ attn_sum_d,
308
+ const scalar_t* const __restrict__ bilinear_weight_d,
309
+ const int* const __restrict__ select_index_d,
310
+ int L, const int num_heads, const int h_attn_num,
311
+ const int attn_num)
312
+ {
313
+ constexpr int VEC_ELEMS = sizeof(typename VecType<scalar_t>::Type) / sizeof(scalar_t);
314
+ static_assert(WIN_SIZE % VEC_ELEMS == 0, "WIN_SIZE must be divisible by vector elements");
315
+ using acc_t = float;
316
+
317
+ int l = blockIdx.x * blockDim.x + threadIdx.x;
318
+ int h = blockIdx.y * blockDim.y + threadIdx.y;
319
+ if (l >= L || h >= num_heads)
320
+ return;
321
+
322
+ const int base_attn_id = l * h_attn_num + h * attn_num;
323
+ const int base_sum_idx = l * (num_heads * 4) + h * 4;
324
+
325
+ scalar_t window[WIN_SIZE];
326
+ load_window<WIN_SIZE>(window, attn_d + base_attn_id);
327
+
328
+ acc_t attn_max = -FLT_MAX;
329
+ #pragma unroll 4
330
+ for (int k = 0; k < WIN_SIZE; ++k) {
331
+ if (static_cast<acc_t>(window[k]) > attn_max) {
332
+ attn_max = static_cast<acc_t>(window[k]);
333
+ }
334
+ }
335
+
336
+ #pragma unroll 4
337
+ for (int k = 0; k < WIN_SIZE; ++k) {
338
+ window[k] = static_cast<scalar_t>(expf(static_cast<acc_t>(window[k]) - attn_max));
339
+ }
340
+
341
+ scalar_t window_out[WIN_SIZE] = {0};
342
+
343
+ for (int b = 0; b < 4; ++b) {
344
+ acc_t block_sum = 0.0f;
345
+ const int* block_idx = select_index_d + b * SUB_WIN_SIZE;
346
+
347
+ #pragma unroll 4
348
+ for (int k = 0; k < SUB_WIN_SIZE; ++k) {
349
+ block_sum += static_cast<acc_t>(window[block_idx[k]]);
350
+ }
351
+ block_sum = fmaxf(block_sum, FLT_EPSILON);
352
+ attn_sum_d[base_sum_idx + b] = static_cast<scalar_t>(block_sum);
353
+
354
+ const scalar_t weight = bilinear_weight_d[base_sum_idx + b];
355
+ const scalar_t scale = static_cast<scalar_t>(static_cast<acc_t>(weight) / block_sum);
356
+
357
+ #pragma unroll 4
358
+ for (int k = 0; k < SUB_WIN_SIZE; ++k) {
359
+ const int idx = block_idx[k];
360
+ window_out[idx] = window_out[idx] + window[idx] * scale;
361
+ }
362
+ }
363
+
364
+ // write back to global memory
365
+ store_window<WIN_SIZE>(attn_out_d + base_attn_id, window_out);
366
+ }
367
+
368
+ template <typename scalar_t>
369
+ __global__ void
370
+ bilinear_softmax_forward_general_k(scalar_t* const __restrict__ attn_d,
371
+ scalar_t* const __restrict__ attn_out_d,
372
+ scalar_t* const __restrict__ attn_sum_d,
373
+ const scalar_t* const __restrict__ bilinear_weight_d,
374
+ const int* const __restrict__ select_index_d,
375
+ int L, const int num_heads, const int h_attn_num,
376
+ const int attn_num, const int attn_num_sub)
377
+ {
378
+ int l, h;
379
+ l = blockIdx.x * blockDim.x + threadIdx.x;
380
+ h = blockIdx.y * blockDim.y + threadIdx.y;
381
+ if (l >= L || h >= num_heads)
382
+ return;
383
+
384
+ scalar_t attn_max = -FLT_MAX;
385
+ int base_attn_id = l * h_attn_num + h * attn_num;
386
+ for (int k = 0; k < attn_num; ++k)
387
+ {
388
+ scalar_t attn_val = attn_d[base_attn_id + k];
389
+ if (attn_val > attn_max) {
390
+ attn_max = attn_val;
391
+ }
392
+ }
393
+ __syncthreads();
394
+
395
+ for (int k = 0; k < attn_num; ++k)
396
+ {
397
+ attn_d[base_attn_id + k] = expf(attn_d[base_attn_id + k] - attn_max);
398
+ }
399
+ __syncthreads();
400
+
401
+ for (int b = 0; b < 4; ++b)
402
+ {
403
+ scalar_t attn_sum = 0;
404
+ for (int k = 0; k < attn_num_sub; ++k)
405
+ {
406
+ attn_sum += attn_d[base_attn_id + select_index_d[b*attn_num_sub + k]];
407
+ }
408
+ attn_sum = fmaxf(attn_sum, FLT_EPSILON);
409
+ attn_sum_d[l*(num_heads*4) + h*4 + b] = attn_sum; // save for backward
410
+
411
+ scalar_t weight = bilinear_weight_d[l*num_heads*4 + h*4 + b];
412
+ for (int k = 0; k < attn_num_sub; ++k)
413
+ {
414
+ int select_index = select_index_d[b*attn_num_sub + k];
415
+ attn_out_d[base_attn_id + select_index] +=
416
+ attn_d[base_attn_id + select_index] / attn_sum * weight; // no write conflict
417
+ }
418
+ }
419
+ }
420
+
421
+ template <typename scalar_t>
422
+ __global__ void attention_aggregate_forward_k(
423
+ const scalar_t *__restrict__ v_d,
424
+ scalar_t *__restrict__ out_d,
425
+ const scalar_t *__restrict__ attn_d,
426
+ const int *__restrict__ m_id_d,
427
+ const int* __restrict__ offset_d,
428
+ const int L, const int C, const int num_heads,
429
+ const int key_dim, const int attn_num,
430
+ const bool swap_xy)
431
+ {
432
+ int c, l;
433
+ if (swap_xy)
434
+ {
435
+ l = blockIdx.x * blockDim.x + threadIdx.x;
436
+ c = blockIdx.y * blockDim.y + threadIdx.y;
437
+ }
438
+ else
439
+ {
440
+ c = blockIdx.x * blockDim.x + threadIdx.x;
441
+ l = blockIdx.y * blockDim.y + threadIdx.y;
442
+ }
443
+ if (l >= L || c >= C)
444
+ return;
445
+
446
+ const int h = c / key_dim;
447
+ const int base_id = l*num_heads + h;
448
+ const int base_attn_id = base_id*attn_num;
449
+ const int m_id = m_id_d[base_id];
450
+ float out_sum = 0;
451
+ for (int k = 0; k < attn_num; ++k)
452
+ {
453
+ int key_id = m_id + offset_d[k];
454
+ out_sum += static_cast<float>(attn_d[base_attn_id + k]) *
455
+ static_cast<float>(v_d[key_id * C + c]);
456
+ }
457
+ out_d[l * C + c] = static_cast<scalar_t>(out_sum);
458
+ }
459
+
460
+ // Main fused forward function
461
+ void match_former_fused_forward(
462
+ at::Tensor max_offset,
463
+ at::Tensor q,
464
+ at::Tensor k,
465
+ at::Tensor v,
466
+ at::Tensor output,
467
+ at::Tensor attn_out,
468
+ const int H,
469
+ const int W,
470
+ const std::vector<int64_t>& win_r,
471
+ const int attn_num,
472
+ const std::string& attn_type,
473
+ const float scale)
474
+ {
475
+ const int B = q.size(0);
476
+ const int N = q.size(1);
477
+ const int C = q.size(2);
478
+ const int h = max_offset.size(2);
479
+ const int key_dim = C / h;
480
+ const int L = B * N;
481
+ const int Lh = L * h;
482
+ const int attn_numel = L * h * attn_num;
483
+ const int win_x = win_r[0];
484
+ const int win_y = win_r[1];
485
+ assert(attn_num == (2*win_r[0]+2)*(2*win_r[1]+2));
486
+ const bool swap_xy_match = (h * attn_num < 32);
487
+ const bool swap_xy_agg = (C < 32);
488
+ const int attn_num_sub = (2*win_r[0] + 1)*(2*win_r[1] + 1);
489
+ const int h_attn_num = h * attn_num;
490
+
491
+ // Create temporary tensors
492
+ auto m_id = at::zeros({B, N, h}, at::TensorOptions().dtype(at::kInt).device(max_offset.device()));
493
+ auto bilinear_weight = at::zeros({B, N, h, 4}, max_offset.options());
494
+ auto attn = at::zeros({B, N, h, attn_num}, q.options());
495
+ auto attn_sum = at::zeros({B, N, h, 4}, q.options());
496
+
497
+ // Create offset array for window
498
+ int *offset_d;
499
+ cudaMalloc(&offset_d, sizeof(int) * attn_num);
500
+ int *offset_h = new int[attn_num];
501
+ int num = 0;
502
+ for (int y = -win_y; y <= (win_y + 1); ++y)
503
+ for (int x = -win_x; x <= (win_x + 1); ++x)
504
+ {
505
+ offset_h[num++] = y * W + x;
506
+ }
507
+ cudaMemcpy(offset_d, offset_h, sizeof(int) * attn_num, cudaMemcpyHostToDevice);
508
+ delete[] offset_h;
509
+
510
+ // Create select_index array for bilinear softmax
511
+ int *select_index_d;
512
+ cudaMalloc(&select_index_d, sizeof(int)*4*attn_num_sub);
513
+ int *select_index_h = new int[4*attn_num_sub];
514
+ int win_W = 2*(win_r[0]+1);
515
+ int delta_x[4] = {0, 1, 0, 1};
516
+ int delta_y[4] = {0, 0, 1, 1};
517
+ num = 0;
518
+ for (int b = 0; b < 4; ++b) {
519
+ int d_x = delta_x[b];
520
+ int d_y = delta_y[b];
521
+ for (int y = d_y; y <= 2*win_r[1] + d_y; ++y)
522
+ for (int x = d_x; x <= 2*win_r[0] + d_x; ++x)
523
+ {
524
+ select_index_h[num++] = y * win_W + x;
525
+ }
526
+ }
527
+ cudaMemcpy(select_index_d, select_index_h, sizeof(int)*attn_num_sub*4, cudaMemcpyHostToDevice);
528
+ delete[] select_index_h;
529
+
530
+ // Step 1: Clip offset to id
531
+ {
532
+ int grid = (Lh + 512 - 1) / 512;
533
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, max_offset.scalar_type(), "clip_offset_to_id_k", ([&] {
534
+ clip_offset_to_id_k<scalar_t><<<grid, 512>>>(max_offset.data_ptr<scalar_t>(), m_id.data_ptr<int>(), Lh, h, N, H, W);
535
+ }));
536
+ }
537
+
538
+ // Step 2: Compute bilinear weights
539
+ {
540
+ int grid = (Lh + 512 - 1) / 512;
541
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, max_offset.scalar_type(), "attn_weight_bilinear_forward_k", ([&] {
542
+ attn_weight_bilinear_forward_k<scalar_t><<<grid, 512>>>(max_offset.data_ptr<scalar_t>(), bilinear_weight.data_ptr<scalar_t>(), Lh);
543
+ }));
544
+ }
545
+
546
+ // Step 3: Check max id bounds
547
+ {
548
+ dim3 m_blocks(8, 128);
549
+ dim3 grids((L + m_blocks.x - 1) / m_blocks.x, (h + m_blocks.y - 1) / m_blocks.y);
550
+ check_max_id_k<<<grids, m_blocks>>>(m_id.data_ptr<int>(), L, N, H, W, h, win_x, win_y);
551
+ }
552
+
553
+ // Step 4: Compute attention
554
+ {
555
+ dim3 m_blocks(8, 128);
556
+ dim3 grids((h*attn_num + m_blocks.x - 1) / m_blocks.x, (L + m_blocks.y - 1) / m_blocks.y);
557
+ if (swap_xy_match)
558
+ grids = dim3((L + m_blocks.x - 1) / m_blocks.x, (h*attn_num + m_blocks.y - 1) / m_blocks.y);
559
+
560
+ if (attn_type == "dot_product") {
561
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(q.scalar_type(), "match_attention_dot_product_forward_k", ([&] {
562
+ match_attention_dot_product_forward_k<scalar_t><<<grids, m_blocks>>>(q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), attn.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, N, H, W, C, h, key_dim, attn_num, attn_numel, swap_xy_match);
563
+ }));
564
+ } else if (attn_type == "l1_norm") {
565
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "match_attention_l1_norm_forward_k", ([&] {
566
+ match_attention_l1_norm_forward_k<scalar_t><<<grids, m_blocks>>>(q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), attn.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, N, H, W, C, h, key_dim, attn_num, attn_numel, swap_xy_match);
567
+ }));
568
+ }
569
+ }
570
+
571
+ // Step 5: Scale attention
572
+ {
573
+ int grid = (attn_numel + 512 - 1) / 512;
574
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "scale_attention_k", ([&] {
575
+ scale_attention_k<scalar_t><<<grid, 512>>>(attn.data_ptr<scalar_t>(), static_cast<scalar_t>(scale), attn_numel);
576
+ }));
577
+ }
578
+
579
+ // Step 6: Bilinear softmax
580
+ {
581
+ dim3 m_blocks = (attn_num == 16) ? dim3(128, 4) : dim3(32, 4);
582
+ dim3 grids((L + m_blocks.x - 1) / m_blocks.x, (h + m_blocks.y - 1) / m_blocks.y);
583
+
584
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attn.scalar_type(), "bilinear_softmax_forward", [&] {
585
+ if (attn_num == 16 && attn_num_sub == 9) {
586
+ bilinear_softmax_forward_k<16, 9><<<grids, m_blocks>>>(
587
+ attn.data_ptr<scalar_t>(),
588
+ attn_out.data_ptr<scalar_t>(),
589
+ attn_sum.data_ptr<scalar_t>(),
590
+ bilinear_weight.data_ptr<scalar_t>(),
591
+ select_index_d, L, h, h_attn_num, attn_num
592
+ );
593
+ } else if (attn_num == 36 && attn_num_sub == 25) {
594
+ bilinear_softmax_forward_k<36, 25><<<grids, m_blocks>>>(
595
+ attn.data_ptr<scalar_t>(),
596
+ attn_out.data_ptr<scalar_t>(),
597
+ attn_sum.data_ptr<scalar_t>(),
598
+ bilinear_weight.data_ptr<scalar_t>(),
599
+ select_index_d, L, h, h_attn_num, attn_num
600
+ );
601
+ } else {
602
+ bilinear_softmax_forward_general_k<<<grids, m_blocks>>>(
603
+ attn.data_ptr<scalar_t>(),
604
+ attn_out.data_ptr<scalar_t>(),
605
+ attn_sum.data_ptr<scalar_t>(),
606
+ bilinear_weight.data_ptr<scalar_t>(),
607
+ select_index_d, L, h, h_attn_num, attn_num, attn_num_sub
608
+ );
609
+ }
610
+ });
611
+ }
612
+
613
+ // Step 7: Attention aggregation
614
+ {
615
+ dim3 m_blocks = (attn_num == 16) ? dim3(8, 128) : dim3(8, 32);
616
+ dim3 grids((C + m_blocks.x - 1) / m_blocks.x, (L + m_blocks.y - 1) / m_blocks.y);
617
+ if (swap_xy_agg)
618
+ grids = dim3((L + m_blocks.x - 1) / m_blocks.x, (C + m_blocks.y - 1) / m_blocks.y);
619
+
620
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, v.scalar_type(), "attention_aggregate_forward_k", ([&] {
621
+ attention_aggregate_forward_k<scalar_t><<<grids, m_blocks>>>(v.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), attn_out.data_ptr<scalar_t>(), m_id.data_ptr<int>(), offset_d, L, C, h, key_dim, attn_num, swap_xy_agg);
622
+ }));
623
+ }
624
+
625
+ // Cleanup
626
+ cudaFree(offset_d);
627
+ cudaFree(select_index_d);
628
+ }
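As the step comments above indicate, the fused forward runs as seven kernel launches: clip the continuous offsets to integer window centers, compute the four bilinear weights, shift centers so the (2r+2)-sized window stays inside the image, evaluate per-position attention logits (negative L1 or dot product), scale them, apply the bilinear softmax over the four overlapping sub-windows, and finally aggregate `v` with the resulting weights.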
models/src/match_former_fused_forward.hpp ADDED
@@ -0,0 +1,22 @@
1
+ #ifndef _MATCH_FORMER_FUSED_FORWARD_HPP_
2
+ #define _MATCH_FORMER_FUSED_FORWARD_HPP_
3
+
4
+ #include <vector>
5
+ #include <string>
6
+
7
+ // Fused forward function that combines all match former operations
8
+ void match_former_fused_forward(
9
+ at::Tensor max_offset,
10
+ at::Tensor q,
11
+ at::Tensor k,
12
+ at::Tensor v,
13
+ at::Tensor output,
14
+ at::Tensor attn_out,
15
+ const int H,
16
+ const int W,
17
+ const std::vector<int64_t>& win_r,
18
+ const int attn_num,
19
+ const std::string& attn_type,
20
+ const float scale);
21
+
22
+ #endif
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ imageio==2.9.0
2
+ imageio-ffmpeg==0.4.9
3
+ matplotlib==3.8.4
4
+ opencv-python==4.9.0.80
5
+ pillow==10.2.0
6
+ scikit-image==0.20.0
7
+ scipy==1.9.1
8
+ tensorboard==2.17.0
9
+ setuptools==59.5.0
10
+ psutil==6.0.0
11
+ joblib==1.4.2
12
+ numpy==1.24.4
13
+ tqdm==4.66.2
14
+ timm==0.6.11
utils/file_io.py ADDED
@@ -0,0 +1,37 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import numpy as np
6
+ import sys
7
+
8
+ def write_pfm(file, image, scale=1):
9
+ file = open(file, 'wb')
10
+
11
+ color = None
12
+
13
+ if image.dtype.name != 'float32':
14
+ raise Exception('Image dtype must be float32.')
15
+
16
+ image = np.flipud(image)
17
+
18
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
19
+ color = True
20
+ elif len(image.shape) == 2 or len(
21
+ image.shape) == 3 and image.shape[2] == 1: # greyscale
22
+ color = False
23
+ else:
24
+ raise Exception(
25
+ 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
26
+
27
+ file.write(b'PF\n' if color else b'Pf\n')
28
+ file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
29
+
30
+ endian = image.dtype.byteorder
31
+
32
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
33
+ scale = -scale
34
+
35
+ file.write(b'%f\n' % scale)
36
+
37
+ image.tofile(file)
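A minimal usage sketch for `write_pfm` (the file name and array contents are illustrative):

```python
import numpy as np

disp = np.random.rand(480, 640).astype(np.float32)   # H x W float32 disparity map
write_pfm('disp.pfm', disp)                           # writes a greyscale 'Pf' PFM; the scale sign encodes endianness
```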
utils/utils.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+
6
+ class InputPadder:
7
+ """ Pads images such that dimensions are divisible by padding_factor """
8
+
9
+ def __init__(self, dims, mode='top_right', padding_factor=32):
10
+ self.ht, self.wd = dims[-2:]
11
+ pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
12
+ pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
13
+ if mode == 'sintel':
14
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
15
+ elif mode == 'top_right':
16
+ self._pad = [0, pad_wd, pad_ht, 0]
17
+ elif mode == 'bottom_right':
18
+ self._pad = [0, pad_wd, 0, pad_ht]
19
+ else:
20
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
21
+
22
+ def pad(self, *inputs):
23
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
24
+
25
+ def unpad(self, x):
26
+ ht, wd = x.shape[-2:]
27
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
28
+ return x[..., c[0]:c[1], c[2]:c[3]]
29
+
30
+
31
+ def init_coords(ref):
32
+ B, H, W, C = ref.shape
33
+
34
+ coords = torch.meshgrid(torch.arange(H, device=ref.device, dtype=ref.dtype), torch.arange(W, device=ref.device, dtype=ref.dtype), indexing='ij')
35
+ coords = torch.stack(coords[::-1], dim=-1)
36
+ return coords[None].repeat(B, 1, 1, 1).to(ref.device) # [B, H, W, 2]
37
+
38
+
39
+ def bilinear_sample_by_offset(tgt, offset): # tgt [B, _, H, W], offset [B, H, W, 2]
40
+ _, _, H, W = tgt.shape
41
+
42
+ xgrid, ygrid = offset.split([1, 1], dim=-1)
43
+ xgrid = 2*xgrid/(W-1) - 1
44
+ ygrid = 2*ygrid/(H-1) - 1
45
+ grid = torch.cat([xgrid, ygrid], dim=-1)
46
+
47
+ tgt_to_ref = F.grid_sample(tgt, grid, mode='bilinear', align_corners=True)
48
+ return tgt_to_ref
49
+
50
+ def calc_noc_mask(field, A=2):
51
+ offset = field + init_coords(field) # [B, H, W, 2]
52
+ field_ref_, field_tgt_ = field.chunk(2, dim=0)
53
+ field_ref = torch.cat((field_ref_, field_tgt_), dim=0) # order
54
+ field_tgt = torch.cat((field_tgt_, field_ref_), dim=0) # reverse order
55
+ field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(), offset).permute(0, 2, 3, 1).contiguous()
56
+ field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1) # ref and tgt flows have opposite signs
57
+ noc_mask = (field_diff < A).to(field_diff.dtype)
58
+ return noc_mask
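A short sketch of the intended `InputPadder` round trip (input sizes are illustrative):

```python
import torch

left = torch.randn(1, 3, 375, 1242)    # e.g. a KITTI-sized pair
right = torch.randn(1, 3, 375, 1242)

padder = InputPadder(left.shape, mode='top_right', padding_factor=32)
left_pad, right_pad = padder.pad(left, right)           # both padded to 384 x 1248

# ... run the model on the padded pair; `pred_pad` stands in for its output ...
pred_pad = torch.randn(1, 1, left_pad.shape[-2], left_pad.shape[-1])
pred = padder.unpad(pred_pad)                           # cropped back to 375 x 1242
```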