Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import numpy as np | |
| from fastai.vision import * | |
| from modules_matrn.attention import * | |
| from modules_matrn.model import Model, _default_tfmer_cfg | |
| from modules_matrn.transformer import (PositionalEncoding, | |
| TransformerEncoder, | |
| TransformerEncoderLayer) | |
class BaseSemanticVisual_backbone_feature(Model):
    """Alignment head that fuses flattened backbone visual features with language features.

    Visual tokens (the H*W positions of the backbone feature map) and language
    tokens (T positions) are concatenated and run through a shared transformer
    encoder (``model1``). The visual half is decoded back into T per-character
    features by ``PositionAttention`` (``model2_vis``); the semantic half is used
    as-is. Each half gets its own classifier (``cls_vis`` / ``cls_sem``), and the
    two are also mixed by a learned sigmoid gate (``w_att``) feeding a final
    classifier (``cls``).

    When ``forward`` is called with ``training=True``, some visual tokens that
    the vision attention ranked highest for a randomly chosen character are
    replaced by the learned ``v_token`` (see ``_mask_visual_tokens``).
    """

    def __init__(self, config):
        super().__init__(config)
        d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
        nhead = ifnone(config.model_alignment_nhead, _default_tfmer_cfg['nhead'])
        d_inner = ifnone(config.model_alignment_d_inner, _default_tfmer_cfg['d_inner'])
        # This key used to be read as the misspelled ``model_alignmentl_dropout``,
        # so a configured ``model_alignment_dropout`` was never picked up. Read
        # the correct key first and keep the legacy spelling as a fallback.
        dropout = ifnone(
            config.model_alignment_dropout,
            ifnone(getattr(config, 'model_alignmentl_dropout', None),
                   _default_tfmer_cfg['dropout']))
        activation = ifnone(config.model_alignment_activation, _default_tfmer_cfg['activation'])
        num_layers = ifnone(config.model_alignment_num_layers, 2)

        # Visual-token masking hyper-parameters (used only when training=True).
        self.mask_example_prob = ifnone(config.model_alignment_mask_example_prob, 0.9)
        self.mask_candidate_prob = ifnone(config.model_alignment_mask_candidate_prob, 0.9)
        self.num_vis_mask = ifnone(config.model_alignment_num_vis_mask, 10)

        self.nhead = nhead
        self.d_model = d_model
        self.use_self_attn = ifnone(config.model_alignment_use_self_attn, False)
        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.debug = ifnone(config.global_debug, False)

        # NOTE: creation order below is kept as-is so that seeded parameter
        # initialization stays reproducible.
        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                dim_feedforward=d_inner, dropout=dropout,
                                                activation=activation)
        self.model1 = TransformerEncoder(encoder_layer, num_layers)
        # NOTE(review): table sized for 8*32 visual positions; presumably the
        # flattened attention grid (H*W) must not exceed this — confirm.
        self.pos_encoder_tfm = PositionalEncoding(d_model, dropout=0, max_len=8*32)
        mode = ifnone(config.model_alignment_attention_mode, 'nearest')
        self.model2_vis = PositionAttention(
            max_length=self.max_length,  # dataset_max_length + 1 (stop token)
            mode=mode,
        )
        self.cls_vis = nn.Linear(d_model, self.charset.num_classes)
        self.cls_sem = nn.Linear(d_model, self.charset.num_classes)
        self.w_att = nn.Linear(2 * d_model, d_model)
        # Learned replacement token written over masked visual positions.
        self.v_token = nn.Parameter(torch.empty((1, d_model)))
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)
        self.cls = nn.Linear(d_model, self.charset.num_classes)

    def _mask_visual_tokens(self, v_feature, v_attn, lengths_l):
        """Return a copy of ``v_feature`` with some highly-attended tokens replaced by ``v_token``.

        For each sample, with probability ``mask_example_prob`` a random
        character position below the sample's length is chosen. The
        ``num_vis_mask`` visual positions with the highest attention for that
        character become candidates, and each candidate is replaced with
        probability ``mask_candidate_prob``.

        Args:
            v_feature: (H*W, N, E) flattened visual features.
            v_attn: (N, T, H*W) attention maps.
            lengths_l: (N,) sequence lengths.

        Returns:
            (H*W, N, E) masked copy; the input tensor is not modified.
        """
        # ``v_feature`` is a view of the caller's backbone map, so clone before
        # the in-place writes below to avoid mutating the caller's tensor.
        v_feature = v_feature.clone()
        max_chars = v_attn.shape[1]
        for idx, length in enumerate(lengths_l):
            if np.random.random() > self.mask_example_prob:
                continue
            # Guard against lengths outside [1, T] (randint(0) raises and
            # l_idx >= T would index past the attention maps).
            n_chars = min(int(length), max_chars)
            if n_chars <= 0:
                continue
            l_idx = np.random.randint(n_chars)
            ranked = v_attn[idx, l_idx].argsort(descending=True).cpu().numpy()
            candidates = ranked[:self.num_vis_mask]
            keep = np.random.random(candidates.shape) <= self.mask_candidate_prob
            v_feature[candidates[keep], idx] = self.v_token
        return v_feature

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None, l_logits=None, texts=None, training=True):
        """Fuse language and visual features and predict characters.

        Args:
            l_feature: (N, T, E) language features, where N is batch size, T is
                the sequence length and E is the model dimension.
            v_feature: (N, E, H, W) backbone visual feature map (not modified).
            lengths_l: (N,) language sequence lengths; required when
                ``training`` is True (drives the visual-token masking).
            v_attn: (N, T, H, W) or already-flattened (N, T, H*W) vision
                attention maps.
            l_logits: unused; kept for interface compatibility.
            texts: unused; kept for interface compatibility.
            training: when True, apply random visual-token masking.

        Returns:
            dict with ``logits``/``pt_lengths`` for the gated fusion,
            ``v_logits``/``pt_v_lengths`` and ``s_logits``/``pt_s_lengths`` for
            the visual and semantic branches, ``loss_weight`` and ``name``.
        """
        N, E, H, W = v_feature.size()
        num_vis = H * W
        l_feature = l_feature.permute(1, 0, 2)  # (T, N, E)
        v_feature = v_feature.view(N, E, num_vis).permute(2, 0, 1)  # (H*W, N, E)
        # Flatten the spatial dims of the attention maps; a no-op if already 3-D.
        v_attn = v_attn.reshape(v_attn.shape[0], v_attn.shape[1], -1)  # (N, T, H*W)

        if training:
            v_feature = self._mask_visual_tokens(v_feature, v_attn, lengths_l)

        # Language-side position signal: positional encodings of the visual
        # grid (pos_encoder_tfm applied to zeros), averaged per character with
        # weights taken from the vision attention maps.
        n_attn, _, num_pos = v_attn.shape
        zeros = v_feature.new_zeros((num_pos, n_attn, E))  # (H*W, N, E)
        base_pos = self.pos_encoder_tfm(zeros).permute(1, 0, 2)  # (N, H*W, E)
        base_pos = torch.bmm(v_attn, base_pos).permute(1, 0, 2)  # (T, N, E)
        l_feature = l_feature + base_pos

        # Joint encoding of visual and language tokens.
        sv_feature = torch.cat((v_feature, l_feature), dim=0)  # (H*W+T, N, E)
        sv_feature = self.model1(sv_feature)  # (H*W+T, N, E)
        sv_to_v_feature = sv_feature[:num_vis]  # (H*W, N, E)
        sv_to_s_feature = sv_feature[num_vis:]  # (T, N, E)

        # Visual branch: restore the feature map and decode per-character features.
        sv_to_v_feature = sv_to_v_feature.permute(1, 2, 0).view(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = self._get_length(sv_to_v_logits)  # (N,)

        # Semantic branch.
        sv_to_s_feature = sv_to_s_feature.permute(1, 0, 2)  # (N, T, E)
        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = self._get_length(sv_to_s_logits)  # (N,)

        # Gated fusion of both branches.
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)  # (N, T, 2E)
        f_att = torch.sigmoid(self.w_att(f))  # (N, T, E)
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature
        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        # NOTE(review): loss weight is tripled, presumably because three logit
        # sets are returned — confirm against the loss implementation.
        return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight': self.loss_weight * 3,
                'v_logits': sv_to_v_logits, 'pt_v_lengths': pt_v_lengths,
                's_logits': sv_to_s_logits, 'pt_s_lengths': pt_s_lengths,
                'name': 'alignment'}