import torch


def compute_bilinear_weights(grid):
    """Compute bilinear interpolation weights for BilinearSoftmax.

    Args:
        grid: [..., 2] tensor of continuous (x, y) coordinates.

    Returns:
        weights: [..., 4] tensor [nw, ne, sw, se] — the bilinear weights of
            the four integer corners surrounding each (x, y); they sum to 1.
    """
    x = grid[..., 0]
    y = grid[..., 1]
    # Fractional offsets from the top-left (north-west) integer corner.
    dx = x - torch.floor(x)
    dy = y - torch.floor(y)
    nw = (1 - dx) * (1 - dy)
    ne = dx * (1 - dy)
    sw = (1 - dx) * dy
    se = dx * dy
    return torch.stack([nw, ne, sw, se], dim=-1)


# (dy, dx) shifts of the four bilinear sub-windows inside the padded window,
# in the order nw, ne, sw, se (matches compute_bilinear_weights).
_SUBWINDOW_SHIFTS = ((0, 0), (0, 1), (1, 0), (1, 1))


def compute_match_attention(q, k, m_id, win_r, H, W):
    """Sample keys in a window around each match center and score them.

    Scores are the negative L1 distance between the query and each sampled
    key (higher = better match).

    Args:
        q: [B, N, h, C] query tensor, N = H * W.
        k: [B, N, h, C] key tensor.
        m_id: [B, N, h, 2] integer (long) sampling centers, last dim (x, y).
        win_r: (r_y, r_x) window radii. The sampled window has
            (2*r_y + 2) * (2*r_x + 2) positions — one extra row/column so the
            four bilinear sub-windows can be carved out by attn_scatter.
        H: feature-map height.
        W: feature-map width.

    Returns:
        output: [B, N, h, M] negative-L1 attention logits,
            M = (2*r_y + 2) * (2*r_x + 2), flattened row-major (y-major).
        indices_gather: [B, N, h, M, C] gather indices of the sampled
            positions, reusable by attention_aggregate to sample values.
    """
    B, N, h, C = q.shape
    M = (2 * win_r[0] + 2) * (2 * win_r[1] + 2)
    # NOTE(fix): win_r[0] is the y-radius, matching attn_scatter/attn_gather,
    # which view the flattened window as (2*win_r[0]+2, 2*win_r[1]+2).
    # The previous code used win_r[0] for x, misaligning the flattened
    # window whenever the two radii differ.
    dy = torch.arange(-win_r[0], win_r[0] + 2, device=q.device, dtype=torch.long)
    dx = torch.arange(-win_r[1], win_r[1] + 2, device=q.device, dtype=torch.long)
    dy, dx = torch.meshgrid(dy, dx, indexing='ij')
    offsets = torch.stack((dx, dy), dim=-1).reshape(M, 2)  # [M, 2], (x, y)

    centers = m_id.unsqueeze(3)             # [B, N, h, 1, 2]
    offsets = offsets.view(1, 1, 1, M, 2)   # [1, 1, 1, M, 2]
    coords = centers + offsets              # [B, N, h, M, 2]

    # Clamp out-of-bounds samples to the border (border replication).
    x_coords = coords[..., 0].clamp(0, W - 1)  # [B, N, h, M]
    y_coords = coords[..., 1].clamp(0, H - 1)  # [B, N, h, M]
    # Flattened spatial index; assumes m_id is a long tensor — TODO confirm.
    indices = y_coords * W + x_coords          # [B, N, h, M]

    # Gather the sampled keys along the flattened spatial dim (dim=1).
    k_expanded = k.unsqueeze(3).expand(-1, -1, -1, M, -1)             # [B, N, h, M, C]
    indices_gather = indices.unsqueeze(-1).expand(-1, -1, -1, -1, C)  # [B, N, h, M, C]
    k_sampled = torch.gather(k_expanded, dim=1, index=indices_gather)

    # Negative L1 distance over channels -> [B, N, h, M].
    output = -torch.abs(q.unsqueeze(3) - k_sampled).sum(dim=-1)
    return output, indices_gather


def attn_scatter(attn, win_r):
    """Slice the padded attention window into its four bilinear sub-windows.

    Args:
        attn: [B, N, h, M] logits on the padded window,
            M = (2*win_r[0] + 2) * (2*win_r[1] + 2).
        win_r: (r_y, r_x) window radii.

    Returns:
        attn_sub: [B, N, h, 4, M_sub] logits of the four sub-windows
            (nw, ne, sw, se), M_sub = (2*r_y + 1) * (2*r_x + 1).
    """
    B, N, h, M = attn.shape
    h_sub = 2 * win_r[0] + 1
    w_sub = 2 * win_r[1] + 1
    M_sub = h_sub * w_sub
    # [B, N, h, H_win, W_win]
    attn_2d = attn.view(B, N, h, h_sub + 1, w_sub + 1)
    subs = []
    for dy, dx in _SUBWINDOW_SHIFTS:
        win = attn_2d[..., dy:dy + h_sub, dx:dx + w_sub]
        subs.append(win.reshape(B, N, h, M_sub))
    return torch.stack(subs, dim=3)


def attn_gather(attn_sub, win_r):
    """Scatter-add the four bilinear sub-windows back onto the padded window.

    Layout inverse of attn_scatter: overlapping positions are summed.

    Args:
        attn_sub: [B, N, h, 4, M_sub] sub-window values (nw, ne, sw, se).
        win_r: (r_y, r_x) window radii.

    Returns:
        merged_attn: [B, N, h, M], M = (2*r_y + 2) * (2*r_x + 2).
    """
    B, N, h, _, M_sub = attn_sub.shape
    h_sub = 2 * win_r[0] + 1
    w_sub = 2 * win_r[1] + 1
    merged = torch.zeros(B, N, h, h_sub + 1, w_sub + 1,
                         device=attn_sub.device, dtype=attn_sub.dtype)
    for i, (dy, dx) in enumerate(_SUBWINDOW_SHIFTS):
        win = attn_sub[:, :, :, i, :].view(B, N, h, h_sub, w_sub)
        merged[..., dy:dy + h_sub, dx:dx + w_sub] += win
    return merged.view(B, N, h, -1)


def compute_bilinear_softmax(attn, bilinear_weight, win_r):
    """Bilinear Softmax: softmax attention at a continuous sampling position.

    Each of the four integer-aligned sub-windows gets an independent softmax;
    the results are blended with the bilinear weights of the continuous
    center and merged back onto the padded window.

    Args:
        attn: [B, N, h, M] attention logits on the discrete (padded) window.
        bilinear_weight: [B, N, h, 4] weights [nw, ne, sw, se], e.g. from
            compute_bilinear_weights.
        win_r: (r_y, r_x) window radii.

    Returns:
        output: [B, N, h, M] effective attention at the continuous position.
    """
    attn_sub = attn_scatter(attn, win_r)  # [B, N, h, 4, M_sub]
    attn_weighted = bilinear_weight.unsqueeze(-1) * attn_sub.softmax(dim=-1)
    return attn_gather(attn_weighted, win_r)  # [B, N, h, M]


def attention_aggregate(v, attn, indices_gather, win_r):
    """Aggregate sampled values with the attention weights.

    Args:
        v: [B, N, h, C] value tensor.
        attn: [B, N, h, M] attention weights over the sampled window.
        indices_gather: [B, N, h, M, C] gather indices from
            compute_match_attention (same sampling pattern as the keys).
        win_r: (r_y, r_x) window radii.

    Returns:
        output: [B, N, h * C] attention-weighted sum of sampled values,
            heads flattened into the channel dim.
    """
    B, N, h, C = v.shape
    M = (2 * win_r[0] + 2) * (2 * win_r[1] + 2)
    # [B, N, h, C] -> [B, N, h, M, C]: sample values at the same positions
    # as the keys were sampled.
    v_expanded = v.unsqueeze(3).expand(-1, -1, -1, M, -1)
    v_sampled = torch.gather(v_expanded, dim=1, index=indices_gather)
    output = (attn.unsqueeze(-1) * v_sampled).sum(dim=3)  # [B, N, h, C]
    return output.view(B, N, -1)