File size: 5,675 Bytes
c00ff2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch import nn

class MultiHeadCoAttention(nn.Module):
    """Multi-head attention whose weights are shared by two value streams.

    Attention weights are computed once from ``query`` and ``key`` and then
    applied, head by head, to two value tensors: a per-channel "multi" stream
    (feature size ``multi_dim``) and a "single" stream (feature size
    ``single_dim``). Both streams are therefore mixed over time with exactly
    the same attention map.

    All inputs are time-major, ``(T, B, channels, features)``:

    * ``query``: ``(T_q, B, c, single_dim)`` and ``key``: ``(T_k, B, c,
      single_dim)`` with the same ``c``; the channel axis is folded into the
      feature axis before the dot product.
    * ``multi_value``: ``(T_k, B, ch, multi_dim)``.
    * ``single_value``: ``(T_k, B, c_s, single_dim)``, e.g. ``c_s == 1``.

    Args:
        multi_dim: feature size D' of the multi stream.
        single_dim: feature size D of query, key and the single stream.
        num_heads: number of heads h; must divide both ``multi_dim`` and
            ``single_dim``.
    """

    def __init__(self, multi_dim, single_dim, num_heads):
        assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
        assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
        super().__init__()
        # Query/key projections operate on the single-stream feature size D.
        self.q_proj = nn.Linear(single_dim, single_dim)
        self.k_proj = nn.Linear(single_dim, single_dim)
        self.multi_v_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_v_proj = nn.Linear(single_dim, single_dim)  # D

        self.multi_out_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_out_proj = nn.Linear(single_dim, single_dim)  # D

        self.multi_dim = multi_dim
        self.single_dim = single_dim
        self.num_heads = num_heads

    @staticmethod
    def _split_heads(x, head_dim):
        """Split the last axis into chunks of ``head_dim`` and stack them at dim 1.

        ``(B, a, b, h * head_dim) -> (B, h, a, b, head_dim)``
        """
        return torch.stack(torch.split(x, head_dim, dim=-1), dim=1)

    @staticmethod
    def _merge_heads(x):
        """Undo the head split on an attention output and return it time-major.

        ``(B, h, c, T, head_dim) -> (T, B, c, h * head_dim)``; heads are
        concatenated in the same order ``_split_heads`` produced them.
        """
        x = torch.permute(x, (3, 0, 2, 1, 4))  # (T, B, c, h, head_dim)
        return torch.reshape(x, x.shape[:-2] + (-1,))

    def forward(self, query, key, multi_value, single_value):
        """Attend with ``query``/``key`` and apply the weights to both value streams.

        Args:
            query: ``(T_q, B, c, single_dim)``.
            key: ``(T_k, B, c, single_dim)``; ``c`` must match ``query``.
            multi_value: ``(T_k, B, ch, multi_dim)``.
            single_value: ``(T_k, B, c_s, single_dim)``.

        Returns:
            ``(multi_result, single_result)`` with shapes
            ``(T_q, B, ch, multi_dim)`` and ``(T_q, B, c_s, single_dim)``.
        """
        single_head_dim = self.single_dim // self.num_heads
        multi_head_dim = self.multi_dim // self.num_heads

        # Query/key become batch-major (B, T, c, D); values become (B, c, T, F)
        # so that time is the second-to-last axis for the attention matmul.
        query = torch.transpose(query, 0, 1)
        key = torch.transpose(key, 0, 1)
        multi_value = torch.permute(multi_value, (1, 2, 0, 3))
        single_value = torch.permute(single_value, (1, 2, 0, 3))

        q = self._split_heads(self.q_proj(query), single_head_dim)  # (B, h, T_q, c, D/h)
        k = self._split_heads(self.k_proj(key), single_head_dim)  # (B, h, T_k, c, D/h)
        multi_v = self._split_heads(self.multi_v_proj(multi_value), multi_head_dim)  # (B, h, ch, T_k, D'/h)
        single_v = self._split_heads(self.single_v_proj(single_value), single_head_dim)  # (B, h, c_s, T_k, D/h)

        # Fold channels into features: one query/key vector per head and time step.
        q = q.view(*q.shape[:-2], -1)  # (B, h, T_q, c*D/h)
        k = k.view(*k.shape[:-2], -1)  # (B, h, T_k, c*D/h)

        # Scale by a plain Python float. A 1-element tensor here would take part
        # in type promotion (float32 wins over half/bfloat16), upcasting the
        # scores and making the value matmul below fail with a dtype mismatch;
        # it would also cost an allocation + device copy on every call.
        sim_mat = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # (B, h, T_q, T_k)
        # The inserted axis broadcasts one attention map over every value channel.
        att_mat = nn.functional.softmax(sim_mat, dim=-1).unsqueeze(2)  # (B, h, 1, T_q, T_k)

        # Co-attention: identical weights applied to both value streams.
        multi_result = self._merge_heads(torch.matmul(att_mat, multi_v))  # (T_q, B, ch, D')
        single_result = self._merge_heads(torch.matmul(att_mat, single_v))  # (T_q, B, c_s, D)

        return self.multi_out_proj(multi_result), self.single_out_proj(single_result)


class CoAttention(nn.Module):
    """Fuse per-channel and channel-pooled frame features with attention.

    The input ``x`` of shape ``(T, B, ch, embed_dim)`` is turned into two
    streams: a "single" stream (mean over channels, projected to
    ``single_dim``) and a "multi" stream (each channel projected to
    ``multi_dim``). They are then mixed by:

    1. ``MultiHeadCoAttention``, with the single stream as query and key; the
       resulting attention is applied to both streams (residual + LayerNorm on
       each);
    2. ``nn.MultiheadAttention`` across frames on the single stream
       (residual + LayerNorm);
    3. concatenating the single stream (repeated for every channel) with the
       multi stream and projecting back to ``embed_dim``.

    Args:
        embed_dim: feature size F of the input and output.
        single_dim: feature size D of the single stream.
        multi_dim: feature size D' of the multi stream.
        n_heads: heads for both attention layers; must divide ``single_dim``
            and ``multi_dim``.
        attn_dropout: dropout passed to the cross-frame ``nn.MultiheadAttention``
            only (the co-attention layer has no dropout).
        init_mult: factor applied to ``cat_proj.weight`` by ``scale_weights``.
    """

    def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
                 init_mult=1e-2):
        super().__init__()
        self.init_mult = init_mult

        self.in_single_proj = nn.Linear(embed_dim, single_dim)  # single_dim == D
        self.in_single_ln = nn.LayerNorm(single_dim)

        self.in_multi_proj = nn.Linear(embed_dim, multi_dim)  # multi_dim == D'
        self.in_multi_ln = nn.LayerNorm(multi_dim)

        self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
        self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
        self.mca_single_out_ln = nn.LayerNorm(single_dim)

        # Default nn.MultiheadAttention layout (batch_first=False): (seq, batch, feature).
        self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout, bias=True, kdim=None,
                                                     vdim=None)
        self.mha_ln = nn.LayerNorm(single_dim)

        self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)

        # NOTE(review): not read anywhere in this class; kept in case external
        # code inspects it.
        self.miso = False

    def scale_weights(self):
        """Zero ``cat_proj.bias`` and multiply ``cat_proj.weight`` by ``init_mult``.

        Not called from ``__init__``; the caller decides when to apply it.
        Presumably meant to make the block's output start close to zero —
        confirm with the call site.
        """
        # In-place ops under no_grad instead of mutating ``.data``, which
        # bypasses autograd's version tracking.
        with torch.no_grad():
            self.cat_proj.bias.zero_()
            self.cat_proj.weight.mul_(self.init_mult)

    def forward(self, x):
        """Apply co-attention, cross-frame attention and channel-wise fusion.

        Args:
            x: ``(T, B, ch, embed_dim)`` frame features, time-major.

        Returns:
            Tensor of shape ``(T, B, ch, embed_dim)``.
        """
        # Only the channel count is needed; unpacking still rejects non-4-D input.
        _, _, chans, _ = x.shape

        single_x = torch.mean(x, dim=2)  # (T,B,F)
        single_x = self.in_single_ln(self.in_single_proj(single_x)).unsqueeze(dim=-2)  # (T,B,1,D)

        multi_x = self.in_multi_ln(self.in_multi_proj(x))  # (T,B,ch,D')

        # Co-attention: the single stream supplies query/key for both streams.
        multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x)  # (T,B,ch,D'), (T,B,1,D)
        single_x = single_x + single_mca
        multi_x = self.mca_multi_out_ln(multi_x + multi_mca)  # (T,B,ch,D')
        single_x = torch.squeeze(self.mca_single_out_ln(single_x), -2)  # (T,B,D)

        # Self-attention across frames on the single stream.
        single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False)  # (T,B,D)
        single_x = self.mha_ln(single_mha + single_x)

        # Join representations. expand() broadcasts over channels as a view;
        # torch.cat then writes the concatenated tensor once.
        single_x_tile = single_x.unsqueeze(-2).expand(-1, -1, chans, -1)  # (T,B,ch,D)
        cat_x = torch.cat([single_x_tile, multi_x], dim=-1)  # (T,B,ch,D+D')
        return self.cat_proj(cat_x)  # (T,B,ch,F)