|
|
"""This code is refer from: |
|
|
https://github.com/jjwei66/BUSNet |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .nrtr_decoder import PositionalEncoding, TransformerBlock |
|
|
from .abinet_decoder import _get_mask, _get_length |
|
|
|
|
|
|
|
|
class BUSDecoder(nn.Module):
    """Decoder that reads text from vision, language and fused inputs.

    A single stack of cross-attention ``TransformerBlock`` layers is applied
    three times per forward pass, each time with the same positional queries
    but a different key/value sequence:

    * vision pass: image features followed by the learned ``l_mask``
      placeholder in every token slot;
    * language pass: the learned ``v_mask`` placeholder in every image slot
      followed by the embedded tokens;
    * fused pass: image features followed by the embedded tokens.

    Args:
        in_channels: Channel size of the image features; also used as the
            model width ``d_model``.
        out_channels: Number of output classes of the classifier ``cls``.
        nhead: Number of attention heads handed to each ``TransformerBlock``.
        num_layers: Number of ``TransformerBlock`` layers in the stack.
        dim_feedforward: Feed-forward width handed to each block.
        dropout: Attention and residual dropout rate of each block.
        max_length: Maximum text length; the decoder works on
            ``max_length + 1`` positions.
        ignore_index: Label id that is replaced by 0 before one-hot encoding
            in the pretraining branch.
        pretraining: When true (and the module is in training mode), the
            language input is built from the labels in ``data`` instead of
            from the vision prediction.
        detach: When true (default), the token probabilities taken from the
            vision pass are detached, so the language and fused passes do
            not back-propagate into the vision pass through them.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        # One position more than the configured text length
        # (presumably room for an end-of-sequence token -- TODO confirm).
        self.max_length = max_length + 1
        self.out_channels = out_channels

        # Bias-free projection from class distributions to model width.
        self.proj = nn.Linear(out_channels, d_model, False)
        # NOTE(review): both positional encoders use a fixed dropout of 0.1
        # and do not follow the ``dropout`` argument.
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)

        # Cross-attention only: the queries never attend to each other.
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])

        # Learned placeholders that stand in for the missing modality.
        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        # Learned offsets added to image features / token embeddings.
        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)
        self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
        """Run the queries through every decoder layer and classify them.

        Args:
            q: Query tensor, updated layer by layer.
            x: Key/value sequence each layer cross-attends to.
            mask: Optional mask forwarded to every layer as ``cross_mask``.

        Returns:
            Output of ``self.cls`` applied to the final queries.
        """
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q
        logits = self.cls(output)
        return logits

    def _vision_tokens(self, v_logits):
        """Turn vision logits into class probabilities for the later passes.

        The result is detached from the autograd graph when ``self.detach``
        is true.
        """
        tokens = torch.softmax(v_logits, dim=-1)
        if self.detach:
            tokens = tokens.detach()
        return tokens

    def forward(self, img_feat, data=None):
        """Decode image features in vision, language and fused passes.

        Args:
            img_feat: Image features, unpacked as ``(B, L, C)``. ``C`` has
                to equal ``d_model`` because the features are concatenated
                with ``d_model``-wide parameters along the sequence axis.
            data: Only read when the module is in training mode and
                ``pretraining`` is set. ``data[0]`` holds the label ids and
                ``data[-1]`` the lengths handed to ``_get_mask``.

        Returns:
            In training mode a dict with keys ``'align'`` (a one-element
            list holding the fused logits), ``'lang'`` and ``'vision'``.
            Otherwise the softmax over the last axis of the fused logits.
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        # Queries are the positional encoding of an all-zero input; the same
        # queries are reused for all three passes.
        query = self.pos_encoder(zeros)

        # Vision pass: token slots carry only the ``l_mask`` placeholder.
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
                            dim=1)
        # Every sample is given the full length ``max_length`` here.
        padding_mask = _get_mask(
            self.max_length + zeros_len,
            self.max_length)
        # All-zero mask block for the L image positions.
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1,
                                                           1])
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        if self.training and self.pretraining:
            # Build the language input from the labels; ignored positions
            # are mapped to class 0 so that one_hot accepts them.
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            # Build the language input from the vision prediction.
            tokens = self._vision_tokens(v_logits)
            lengths = _get_length(v_logits)
        token_embed = self.proj(tokens)
        token_embed = self.token_encoder(token_embed)
        token_embed = token_embed + self.l_embeding

        # The mask built from ``lengths`` is shared by the next two passes.
        padding_mask = _get_mask(lengths,
                                 self.max_length)
        mask = torch.cat((v_mask, padding_mask), 3)

        # Language pass: image slots carry only the ``v_mask`` placeholder.
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # Fused pass: real image features together with token embeddings.
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
|
|
|