Spaces:
Build error
Build error
| import torch.nn as nn | |
| from .trm import * | |
class _MultiHeadAttention(nn.Module):
    """Project q, k and v and split each projection into attention heads.

    This module performs only the per-head input projections; it does not
    compute attention. The caller feeds the returned per-head tensors to an
    attention function.

    Args:
        d_k: Per-head size of queries and keys.
        d_v: Per-head size of values.
        d_model: Feature size expected on the last axis of q, k and v.
        n_heads: Number of attention heads.
        dropout: Accepted for interface compatibility; not used here.
    """

    def __init__(self, d_k, d_v, d_model, n_heads, dropout):
        super().__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.n_heads = n_heads
        # Creation order (q, k, v) is preserved so weight initialisation
        # draws from the RNG in the same sequence as before.
        self.w_q = Linear(d_model, d_k * n_heads)
        self.w_k = Linear(d_model, d_k * n_heads)
        self.w_v = Linear(d_model, d_v * n_heads)

    def _split_heads(self, projection, inputs, batch, head_dim):
        """Project *inputs* and reshape to [batch x n_heads x len x head_dim]."""
        per_head = projection(inputs).view(batch, -1, self.n_heads, head_dim)
        return per_head.transpose(1, 2)

    def forward(self, q, k, v):
        """Project q, k, v and split them into heads.

        Args:
            q: [b_size x len_q x d_model]
            k: [b_size x len_k x d_model]
            v: [b_size x len_k x d_model]

        Returns:
            Tuple ``(q_s, k_s, v_s)`` with shapes
            [b_size x n_heads x len_q x d_k],
            [b_size x n_heads x len_k x d_k],
            [b_size x n_heads x len_k x d_v].
        """
        # The batch size comes from q and is reused for k and v.
        batch = q.size(0)
        heads_q = self._split_heads(self.w_q, q, batch, self.d_k)
        heads_k = self._split_heads(self.w_k, k, batch, self.d_k)
        heads_v = self._split_heads(self.w_v, v, batch, self.d_v)
        return heads_q, heads_k, heads_v
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sublayer with a residual connection.

    Two kernel-size-1 ``Conv1d`` layers (d_model -> d_ff -> d_model) with a
    ReLU between them act on every position independently. Their output goes
    through dropout, is added to the input and is passed to
    ``LayerNormalization``.

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Hidden (inner) size.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names are state_dict keys and creation order fixes the
        # init RNG sequence, so both are kept as they were.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = LayerNormalization(d_model)

    def forward(self, inputs):
        """Apply the feed-forward sublayer.

        Args:
            inputs: [b_size x len_q x d_model]

        Returns:
            [b_size x len_q x d_model] (assuming LayerNormalization keeps
            the shape).
        """
        # Conv1d wants channels on dim 1: move features there, then back.
        channels_first = inputs.transpose(1, 2)
        hidden = self.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)
        return self.layer_norm(inputs + self.dropout(projected))
class MultiHeadAttention(nn.Module):
    """Bidirectional cross-modal multi-head attention between two sequences.

    Each stream is mapped to ``d_model`` features (``linear_v`` /
    ``linear_s``) and split into heads. Visual queries attend over the
    sentence keys/values, and sentence queries attend over the visual
    keys/values, through one shared ``ScaledDotProductAttention``. Each
    context is projected back to ``d_model`` and passed through dropout. It
    is then added to its residual and layer-normalised.

    Args:
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        n_heads: Number of heads.
        dropout: Probability for the output dropout; also passed to
            ``ScaledDotProductAttention``.
        d_model: Common model width.
        visual_len: ``visual_len * 10`` is the first argument given to the
            visual ``PosEncoding`` (presumably a maximum length - confirm in
            ``trm``).
        sen_len: Same as ``visual_len``, for the sentence stream.
        fea_v: Input feature size of the visual stream.
        fea_s: Input feature size of the sentence stream.
        pos: If truthy, positional encodings are added to the residuals.
    """

    def __init__(self, d_k, d_v, n_heads, dropout, d_model, visual_len, sen_len, fea_v, fea_s, pos):
        super().__init__()
        # Submodule names (state_dict keys) and creation order (init RNG
        # sequence) are intentionally unchanged.
        self.n_heads = n_heads
        self.multihead_attn_v = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout)
        self.multihead_attn_s = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout)
        self.pos_emb_v = PosEncoding(visual_len * 10, d_model)
        self.pos_emb_s = PosEncoding(sen_len * 10, d_model)
        self.linear_v = nn.Linear(in_features=fea_v, out_features=d_model)
        self.linear_s = nn.Linear(in_features=fea_s, out_features=d_model)
        self.proj_v = Linear(n_heads * d_v, d_model)
        self.proj_s = Linear(n_heads * d_v, d_model)
        self.d_v = d_v
        self.dropout = nn.Dropout(dropout)
        self.layer_norm_v = LayerNormalization(d_model)
        self.layer_norm_s = LayerNormalization(d_model)
        self.attention = ScaledDotProductAttention(d_k, dropout)
        self.pos = pos

    def _merge_heads(self, context, batch):
        """Fold heads back: [batch x n_heads x len x d_v] -> [batch x len x n_heads*d_v].

        Assumes ``self.attention`` returns the per-head layout shown above.
        """
        return context.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_v)

    def forward(self, v, s, v_len, s_len):
        """Run cross-attention in both directions.

        Args:
            v: Visual features; assumes [b_size x len_v x fea_v].
            s: Sentence features; assumes [b_size x len_s x fea_s].
            v_len: Passed to ``pos_emb_v`` when ``pos`` is set; otherwise unused.
            s_len: Passed to ``pos_emb_s`` when ``pos`` is set; otherwise unused.
                No attention mask is built from either length.

        Returns:
            Tuple ``(v_out, s_out)`` of layer-normalised visual and sentence
            representations with ``d_model`` features.
        """
        batch = v.size(0)
        vis = self.linear_v(v)
        sen = self.linear_s(s)

        if self.pos:
            # NOTE(review): the positional encodings only reach the residual
            # branch; the attention below uses the un-encoded projections.
            # Confirm this is intended before changing (checkpoints may rely on it).
            enc_vis = self.pos_emb_v(v_len)
            enc_sen = self.pos_emb_s(s_len)
            res_vis = vis + enc_vis
            res_sen = sen + enc_sen
        else:
            res_vis = vis
            res_sen = sen

        q_vis, k_vis, val_vis = self.multihead_attn_v(vis, vis, vis)
        q_sen, k_sen, val_sen = self.multihead_attn_s(sen, sen, sen)

        # Cross-modal: each modality's queries attend over the other's keys/values.
        ctx_vis, _ = self.attention(q_vis, k_sen, val_sen)
        ctx_sen, _ = self.attention(q_sen, k_vis, val_vis)
        ctx_vis = self._merge_heads(ctx_vis, batch)
        ctx_sen = self._merge_heads(ctx_sen, batch)

        # Project back to d_model, then residual add + layer norm.
        out_vis = self.dropout(self.proj_v(ctx_vis))
        out_sen = self.dropout(self.proj_s(ctx_sen))
        return self.layer_norm_v(res_vis + out_vis), self.layer_norm_s(res_sen + out_sen)
class co_attention(nn.Module):
    """One co-attention layer: cross-modal attention, then per-modality FFNs.

    The visual and sentence streams pass through ``MultiHeadAttention``.
    Each result then goes through its own ``PoswiseFeedForwardNet``.

    Args:
        d_k, d_v, n_heads, dropout, d_model, visual_len, sen_len, fea_v,
        fea_s, pos: Forwarded unchanged to ``MultiHeadAttention``;
            ``d_model`` and ``dropout`` are also given to both feed-forward
            sublayers.
        d_ff: Hidden size of both feed-forward sublayers. Defaults to 128,
            the value that used to be hard-coded, so existing callers and
            checkpoints are unaffected.
    """

    def __init__(self, d_k, d_v, n_heads, dropout, d_model, visual_len, sen_len,
                 fea_v, fea_s, pos, d_ff=128):
        super(co_attention, self).__init__()
        # Attribute names are state_dict keys; keep them (and their order).
        self.multi_head = MultiHeadAttention(
            d_k=d_k, d_v=d_v, n_heads=n_heads, dropout=dropout, d_model=d_model,
            visual_len=visual_len, sen_len=sen_len, fea_v=fea_v, fea_s=fea_s, pos=pos,
        )
        self.PoswiseFeedForwardNet_v = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff, dropout=dropout)
        self.PoswiseFeedForwardNet_s = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff, dropout=dropout)

    def forward(self, v, s, v_len, s_len):
        """Apply cross-attention, then the feed-forward sublayers.

        Args:
            v: Visual features (see ``MultiHeadAttention.forward``).
            s: Sentence features.
            v_len: Forwarded to ``MultiHeadAttention``.
            s_len: Forwarded to ``MultiHeadAttention``.

        Returns:
            Tuple ``(v, s)`` of updated visual and sentence representations.
        """
        v, s = self.multi_head(v, s, v_len, s_len)
        v = self.PoswiseFeedForwardNet_v(v)
        s = self.PoswiseFeedForwardNet_s(s)
        return v, s