File size: 5,409 Bytes
0940df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch

def compute_bilinear_weights(grid):
    """
    Bilinear interpolation weights of the four integer neighbours of each
    sampling position (used by BilinearSoftmax).

    Args:
        grid: [..., 2], last dim is (x, y)
    Returns:
        weights: [..., 4], ordered [nw, ne, sw, se]
    """
    # Fractional part of every coordinate relative to its floor.
    frac = grid - torch.floor(grid)
    fx = frac[..., 0]
    fy = frac[..., 1]

    # Complementary fractions: weight toward the floor-side neighbour.
    gx = 1 - fx
    gy = 1 - fy

    corner_weights = [
        gx * gy,  # nw
        fx * gy,  # ne
        gx * fy,  # sw
        fx * fy,  # se
    ]
    return torch.stack(corner_weights, dim=-1)

def compute_match_attention(q, k, m_id, win_r, H, W):
    """
    Negative L1 similarity between each query and the keys inside a window
    around its sampling center.

    Args:
        q: [B, N, h, C]     # Query tensor
        k: [B, N, h, C]     # Key tensor
        m_id: [B, N, h, 2]  # Sampling centers, last dim is (x, y).
                            # Floating-point centers are floored to integers.
        win_r: sequence of two ints, window radius along (x, y)
        H: int              # Height
        W: int              # Width

    Returns:
        output: [B, N, h, M] where M = (2*win_r[0]+2)*(2*win_r[1]+2)
        indices_gather: [B, N, h, M, C] int64 indices (y*W + x) into the
            N axis, as used to gather the keys

    The N axis is indexed with y*W + x, so it is assumed to be the flattened
    H*W grid -- TODO confirm with callers.

    NOTE(review): the window is laid out with 2*win_r[1]+2 rows (y) and
    2*win_r[0]+2 columns (x), while attn_scatter/attn_gather in this file
    view the M axis as (2*win_r[0]+2, 2*win_r[1]+2). The two layouts agree
    only when win_r[0] == win_r[1]; confirm the intended convention before
    using non-square windows.
    """
    B, N, h, C = q.shape
    M = (2*win_r[0] + 2)*(2*win_r[1] + 2)

    # torch.gather needs an int64 index; floor fractional centers the same
    # way compute_bilinear_weights floors its grid.
    if torch.is_floating_point(m_id):
        m_id = torch.floor(m_id)
    m_id = m_id.long()

    # Window offsets run from -r to r+1 so that the four shifted
    # (2r+1)-sized sub-windows are all covered.
    dx = torch.arange(-win_r[0], win_r[0] + 2, device=q.device, dtype=torch.long)
    dy = torch.arange(-win_r[1], win_r[1] + 2, device=q.device, dtype=torch.long)
    dy, dx = torch.meshgrid(dy, dx, indexing='ij')
    offsets = torch.stack((dx, dy), dim=-1).reshape(1, 1, 1, M, 2)  # [1, 1, 1, M, 2]

    coords = m_id.unsqueeze(3) + offsets  # [B, N, h, M, 2]

    # Clamp coordinates to valid range (out-of-image taps reuse border pixels)
    x_coords = coords[..., 0].clamp(0, W - 1)  # [B, N, h, M]
    y_coords = coords[..., 1].clamp(0, H - 1)  # [B, N, h, M]

    indices = y_coords * W + x_coords  # [B, N, h, M]

    # [B, N, h, C] -> [B, N, h, M, C]
    k_expanded = k.unsqueeze(3).expand(-1, -1, -1, M, -1)

    # [B, N, h, M] -> [B, N, h, M, C]
    indices_gather = indices.unsqueeze(-1).expand(-1, -1, -1, -1, C)

    # Gather along the N axis: k_sampled[b, n, h, m, c] = k[b, idx, h, c]
    k_sampled = torch.gather(k_expanded, dim=1, index=indices_gather)

    # [B, N, h, M, C] -> [B, N, h, M], negative L1 norm
    output = -torch.abs(q.unsqueeze(3) - k_sampled).sum(dim=-1)

    return output, indices_gather

def attn_scatter(attn, win_r):
    """
    Split the full attention window into four overlapping sub-windows,
    each shifted by at most one position along each window axis.

    Args:
        attn: [B, N, h, M], M = (2*win_r[0]+2) * (2*win_r[1]+2)
        win_r: window radius (two entries)

    Returns:
        attn_sub: [B, N, h, 4, M_sub] sub-windows ordered [nw, ne, sw, se],
                  M_sub = (2*win_r[0]+1) * (2*win_r[1]+1)
    """
    B, N, h, _ = attn.shape
    sub_rows = 2*win_r[0] + 1
    sub_cols = 2*win_r[1] + 1

    # Unflatten M: the full window has one more row and column than a sub-window.
    window = attn.view(B, N, h, sub_rows + 1, sub_cols + 1)

    # (row_start, col_start) of each sub-window, in order nw, ne, sw, se.
    origins = ((0, 0), (0, 1), (1, 0), (1, 1))
    pieces = [
        window[..., r0:r0 + sub_rows, c0:c0 + sub_cols].reshape(B, N, h, sub_rows * sub_cols)
        for r0, c0 in origins
    ]

    return torch.stack(pieces, dim=3)

def attn_gather(attn_sub, win_r):
    """
    Merge four sub-window attentions back into the full window by summing
    them at their shifted positions (inverse layout of attn_scatter).

    Args:
        attn_sub: [B, N, h, 4, M_sub], sub-windows ordered [nw, ne, sw, se]
        win_r: window radius (two entries)

    Returns:
        merged_attn: [B, N, h, M], M = (2*win_r[0]+2) * (2*win_r[1]+2)
    """
    B, N, h, _, _ = attn_sub.shape
    sub_rows = 2*win_r[0] + 1
    sub_cols = 2*win_r[1] + 1

    canvas = torch.zeros(
        (B, N, h, sub_rows + 1, sub_cols + 1),
        dtype=attn_sub.dtype,
        device=attn_sub.device,
    )

    # (row_start, col_start) of each sub-window, in order nw, ne, sw, se.
    # Cells covered by several sub-windows accumulate all their contributions.
    origins = ((0, 0), (0, 1), (1, 0), (1, 1))
    for slot, (r0, c0) in enumerate(origins):
        piece = attn_sub[:, :, :, slot, :].view(B, N, h, sub_rows, sub_cols)
        canvas[..., r0:r0 + sub_rows, c0:c0 + sub_cols] += piece

    return canvas.view(B, N, h, -1)

def compute_bilinear_softmax(attn, bilinear_weight, win_r):
    """
    Bilinear softmax: attention sampled at a continuous position.

    The full window is split into four shifted sub-windows, each sub-window
    is normalised with its own softmax, scaled by its bilinear weight, and
    the four results are summed back into the full window.

    Args:
        attn: [B, N, h, M] attention scores on discrete positions
        bilinear_weight: weights of the four sub-windows; after a trailing
            unsqueeze it must broadcast against [B, N, h, 4, M_sub]
            (presumably [B, N, h, 4] from compute_bilinear_weights -- confirm)
        win_r: window radius

    Returns:
        output: [B, N, h, M] effective attention at the continuous position
    """
    # [B, N, h, 4, M_sub], softmax taken independently inside each sub-window
    sub_probs = attn_scatter(attn, win_r).softmax(dim=-1)

    weighted = bilinear_weight.unsqueeze(-1) * sub_probs

    # [B, N, h, M]
    return attn_gather(weighted, win_r)

def attention_aggregate(v, attn, indices_gather, win_r):
    """
    Aggregate values inside each sampling window with the given attention.

    Args:
        v: [B, N, h, C] value tensor
        attn: [B, N, h, M] attention over the window positions
        indices_gather: [B, N, h, M, C'] int64 indices into the N axis, as
            returned by compute_match_attention. Only the first channel
            slice is read, so C' need not equal C.
        win_r: window radius, M = (2*win_r[0]+2)*(2*win_r[1]+2)

    Returns:
        output: [B, N, h*C] attention-weighted sum of the sampled values
    """
    B, N, h, C = v.shape
    M = (2*win_r[0] + 2)*(2*win_r[1] + 2)

    # The incoming index carries the channel width of q/k. Re-expand it to
    # v's channel width: torch.gather returns a tensor shaped like the index,
    # so a narrower index would silently drop value channels and a wider one
    # would raise.
    index = indices_gather[..., :1].expand(-1, -1, -1, -1, C)

    # [B, N, h, C] -> [B, N, h, M, C]
    v_expanded = v.unsqueeze(3).expand(-1, -1, -1, M, -1)
    # v_sampled[b, n, h, m, c] = v[b, index[b, n, h, m, c], h, c]
    v_sampled = torch.gather(v_expanded, dim=1, index=index)

    # Weighted sum over the window axis: [B, N, h, M, C] -> [B, N, h, C]
    output = (attn.unsqueeze(-1) * v_sampled).sum(dim=3)

    return output.reshape(B, N, -1)