import torch.nn as nn from .trm import * class _MultiHeadAttention(nn.Module): def __init__(self, d_k, d_v, d_model, n_heads, dropout): super(_MultiHeadAttention, self).__init__() self.d_k = d_k self.d_v = d_v self.d_model = d_model self.n_heads = n_heads self.w_q = Linear(d_model, d_k * n_heads) self.w_k = Linear(d_model, d_k * n_heads) self.w_v = Linear(d_model, d_v * n_heads) def forward(self, q, k, v): # q: [b_size x len_q x d_model] # k: [b_size x len_k x d_model] # v: [b_size x len_k x d_model] b_size = q.size(0) # q_s: [b_size x n_heads x len_q x d_k] # k_s: [b_size x n_heads x len_k x d_k] # v_s: [b_size x n_heads x len_k x d_v] q_s = self.w_q(q).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2) k_s = self.w_k(k).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2) v_s = self.w_v(v).view(b_size, -1, self.n_heads, self.d_v).transpose(1, 2) return q_s, k_s, v_s class PoswiseFeedForwardNet(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super(PoswiseFeedForwardNet, self).__init__() self.relu = nn.ReLU() self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.dropout = nn.Dropout(dropout) self.layer_norm = LayerNormalization(d_model) def forward(self, inputs): # inputs: [b_size x len_q x d_model] residual = inputs output = self.relu(self.conv1(inputs.transpose(1, 2))) # outputs: [b_size x len_q x d_model] output = self.conv2(output).transpose(1, 2) output = self.dropout(output) return self.layer_norm(residual + output) class MultiHeadAttention(nn.Module): def __init__(self, d_k, d_v, n_heads, dropout, d_model, visual_len, sen_len, fea_v, fea_s, pos): super(MultiHeadAttention, self).__init__() self.n_heads = n_heads self.multihead_attn_v = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout) self.multihead_attn_s = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout) self.pos_emb_v = PosEncoding(visual_len * 10, d_model) self.pos_emb_s = PosEncoding(sen_len * 10, d_model) self.linear_v = nn.Linear(in_features=fea_v, out_features=d_model) self.linear_s = nn.Linear(in_features=fea_s, out_features=d_model) self.proj_v = Linear(n_heads * d_v, d_model) self.proj_s = Linear(n_heads * d_v, d_model) self.d_v = d_v self.dropout = nn.Dropout(dropout) self.layer_norm_v = LayerNormalization(d_model) self.layer_norm_s = LayerNormalization(d_model) self.attention = ScaledDotProductAttention(d_k, dropout) self.pos = pos def forward(self, v, s, v_len, s_len): b_size = v.size(0) # q: [b_size x len_q x d_model] # k: [b_size x len_k x d_model] # v: [b_size x len_v x d_model] note (len_k == len_v) v, s = self.linear_v(v), self.linear_s(s) if self.pos: pos_v, pos_s = self.pos_emb_v(v_len), self.pos_emb_s(s_len) residual_v, residual_s = v + pos_v, s + pos_s else: residual_v, residual_s = v, s # context: a tensor of shape [b_size x len_q x n_heads * d_v] q_v, k_v, v_v = self.multihead_attn_v(v, v, v) q_s, k_s, v_s = self.multihead_attn_s(s, s, s) context_v, attn_v = self.attention(q_v, k_s, v_s) context_s, attn_s = self.attention(q_s, k_v, v_v) context_v = context_v.transpose(1, 2).contiguous().view(b_size, -1, self.n_heads * self.d_v) context_s = context_s.transpose(1, 2).contiguous().view(b_size, -1, self.n_heads * self.d_v) # project back to the residual size, outputs: [b_size x len_q x d_model] output_v = self.dropout(self.proj_v(context_v)) output_s = self.dropout(self.proj_s(context_s)) return self.layer_norm_v(residual_v + output_v), self.layer_norm_s(residual_s + output_s) class co_attention(nn.Module): def __init__(self, d_k, d_v, n_heads, dropout, d_model, visual_len, sen_len, fea_v, fea_s, pos): super(co_attention, self).__init__() # self.layer_num = layer_num # self.multi_head = MultiHeadAttention(d_k=d_k, d_v=d_v, n_heads=n_heads, dropout=dropout, d_model=d_model, # visual_len=visual_len, sen_len=sen_len, fea_v=fea_v, fea_s=fea_s, pos=False) # self.PoswiseFeedForwardNet_v = nn.ModuleList([PoswiseFeedForwardNet(d_model=d_model, d_ff=256)]) # self.PoswiseFeedForwardNet_s = nn.ModuleList([PoswiseFeedForwardNet(d_model=d_model, d_ff=256)]) # self.multi_head = nn.ModuleList([MultiHeadAttention(d_k=d_k, d_v=d_v, n_heads=n_heads, dropout=dropout, d_model=d_model, # visual_len=visual_len, sen_len=sen_len, fea_v=fea_v, fea_s=fea_s, pos=False)]) # for i in range(1, layer_num): # self.PoswiseFeedForwardNet_v.append(PoswiseFeedForwardNet(d_model=d_model, d_ff=256)) # self.PoswiseFeedForwardNet_s.append(PoswiseFeedForwardNet(d_model=d_model, d_ff=256)) # self.multi_head.append(MultiHeadAttention(d_k=d_k, d_v=d_v, n_heads=n_heads, dropout=dropout, d_model=d_model, # visual_len=visual_len, sen_len=sen_len, fea_v=d_model, fea_s=d_model, pos=True)) self.multi_head = MultiHeadAttention(d_k=d_k, d_v=d_v, n_heads=n_heads, dropout=dropout, d_model=d_model, visual_len=visual_len, sen_len=sen_len, fea_v=fea_v, fea_s=fea_s, pos=pos) self.PoswiseFeedForwardNet_v = PoswiseFeedForwardNet(d_model=d_model, d_ff=128, dropout=dropout) self.PoswiseFeedForwardNet_s = PoswiseFeedForwardNet(d_model=d_model, d_ff=128,dropout=dropout) def forward(self, v, s, v_len, s_len): # for i in range(self.layer_num): # v, s = self.multi_head[i](v, s, v_len, s_len) # v = self.PoswiseFeedForwardNet_v[i](v) # s = self.PoswiseFeedForwardNet_s[i](s) v, s = self.multi_head(v, s, v_len, s_len) v = self.PoswiseFeedForwardNet_v(v) s = self.PoswiseFeedForwardNet_s(s) return v, s