| | |
| | |
| | |
| | import math |
| | from typing import Dict, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint as checkpoint |
| | from mmcv.cnn.bricks import DropPath |
| | from torch import Tensor |
| |
|
| | try: |
| | from transformers import BertConfig, BertPreTrainedModel |
| | from transformers.modeling_utils import apply_chunking_to_forward |
| | from transformers.models.bert.modeling_bert import \ |
| | BertAttention as HFBertAttention |
| | from transformers.models.bert.modeling_bert import \ |
| | BertIntermediate as HFBertIntermediate |
| | from transformers.models.bert.modeling_bert import \ |
| | BertOutput as HFBertOutput |
| | except ImportError: |
| | BertConfig = None |
| | BertPreTrainedModel = object |
| | apply_chunking_to_forward = None |
| | HFBertAttention = object |
| | HFBertIntermediate = object |
| | HFBertOutput = object |
| |
|
| | MAX_CLAMP_VALUE = 50000 |
| |
|
| |
|
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int,
                        W: int) -> Tensor:
    """Rearrange a feature map from (N, A * C, H, W) to (N, H * W * A, C).

    The input is first split into ``A`` groups of ``C`` channels, then the
    spatial positions and groups are moved in front of the channel axis.

    Args:
        layer (Tensor): Input tensor of shape (N, A * C, H, W).
        N (int): Batch size.
        A (int): Number of attention heads (channel groups).
        C (int): Number of channels per group.
        H (int): Height of the feature map.
        W (int): Width of the feature map.

    Returns:
        Tensor: Tensor of shape (N, H * W * A, C).
    """
    grouped = layer.view(N, A, C, H, W).permute(0, 3, 4, 1, 2)
    return grouped.reshape(N, -1, C)
| |
|
| |
|
def clamp_values(vector: Tensor) -> Tensor:
    """Restrict all entries of *vector* to [-MAX_CLAMP_VALUE, MAX_CLAMP_VALUE].

    Keeps activations inside a range that is safe for half-precision
    arithmetic.

    Args:
        vector (Tensor): Input tensor of shape (N, C, H, W).

    Returns:
        Tensor: Tensor of the same shape with out-of-range values clamped.
    """
    return vector.clamp(min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE)
| |
|
| |
|
class BiMultiHeadAttention(nn.Module):
    """Bidirectional fusion Multi-Head Attention layer.

    A single score matrix between visual and language tokens is computed
    once, then normalized along each of its two axes to produce both the
    vision-to-language and language-to-vision attention.

    Args:
        v_dim (int): The dimension of the vision input.
        l_dim (int): The dimension of the language input.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): The number of attention heads.
        dropout (float, optional): The dropout probability. Defaults to 0.1.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1):
        super(BiMultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Per-head channel width; the assert below checks it divides evenly.
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), 'embed_dim must be divisible by num_heads ' \
            f'(got `embed_dim`: {self.embed_dim} ' \
            f'and `num_heads`: {self.num_heads}).'
        # 1/sqrt(head_dim) scaling, folded into the query projection output.
        self.scale = self.head_dim**(-0.5)
        self.dropout = dropout

        # Query projections for the shared score matrix.
        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        # Value projections for each modality.
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        # Output projections back to each modality's original dimension.
        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        # Numerical-stability switches (relevant for fp16): optional global
        # max subtraction before softmax and clamping of extreme scores.
        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        # Xavier-uniform weights and zero biases for every projection.
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def forward(
        self,
        vision: Tensor,
        lang: Tensor,
        attention_mask_v: Optional[Tensor] = None,
        attention_mask_l: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Cross-attend vision and language features in both directions.

        Args:
            vision (Tensor): Visual features of shape (bsz, tgt_len, v_dim).
            lang (Tensor): Language features of shape (bsz, src_len, l_dim).
            attention_mask_v (Optional[Tensor]): 2-D mask over visual
                positions; True entries are excluded from the
                language-to-vision attention (filled with -inf).
            attention_mask_l (Optional[Tensor]): 2-D mask of shape
                (bsz, src_len); entries equal to 0 are excluded from the
                vision-to-language attention.

        Returns:
            Tuple[Tensor, Tensor]: Updated visual features of shape
            (bsz, tgt_len, v_dim) and language features of shape
            (bsz, src_len, l_dim).
        """
        bsz, tgt_len, _ = vision.size()

        # Pre-scaled vision queries; language keys; values for both sides.
        query_states = self.v_proj(vision) * self.scale
        key_states = self._shape(self.l_proj(lang), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(vision), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(lang), -1, bsz)

        # Fold heads into the batch axis so one bmm serves all heads.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len,
                                   bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        # Shared score matrix: rows index vision, columns index language.
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f'Attention weights should be of '
                f'size {(bsz * self.num_heads, tgt_len, src_len)}, '
                f'but is {attn_weights.size()}')

        if self.stable_softmax_2d:
            # Subtract the global max for extra softmax stability.
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            # Clamp to -50000: fp16 has a very limited dynamic range.
            attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Clamp to 50000 for the same fp16 reason.
            attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE)

        # Language->vision direction: transpose the shared scores, then
        # subtract each row's max (per language token) before softmax.
        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = (
            attn_weights_T -
            torch.max(attn_weights_T, dim=-1, keepdim=True)[0])
        if self.clamp_min_for_underflow:
            # Re-clamp after the max subtraction (fp16 safety).
            attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE)

        if attention_mask_v is not None:
            # Broadcast the visual mask over heads and flatten to match the
            # (bsz * num_heads, src_len, tgt_len) score layout.
            attention_mask_v = (
                attention_mask_v[:, None,
                                 None, :].repeat(1, self.num_heads, 1,
                                                 1).flatten(0, 1))
            # In-place fill: masked visual positions get zero probability.
            attn_weights_l.masked_fill_(attention_mask_v, float('-inf'))

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert (attention_mask_l.dim() == 2)
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            # Large negative (rather than -inf) keeps fp16 softmax finite.
            attention_mask = attention_mask.masked_fill(
                attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError('Attention mask should be of '
                                 f'size {(bsz, 1, tgt_len, src_len)}')
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        # Vision->language direction: normalize over language tokens.
        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(
            attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(
            attn_weights_l, p=self.dropout, training=self.training)

        # Each modality aggregates the *other* modality's values.
        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_v` should be of '
                f'size {(bsz, self.num_heads, tgt_len, self.head_dim)}, '
                f'but is {attn_output_v.size()}')

        if attn_output_l.size() != (bsz * self.num_heads, src_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_l` should be of size '
                f'{(bsz, self.num_heads, src_len, self.head_dim)}, '
                f'but is {attn_output_l.size()}')

        # Un-fold the heads and merge them back into the embedding axis.
        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len,
                                           self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len,
                                           self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        # Project back to the original vision/language dimensions.
        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l
| |
|
| |
|
class BiAttentionBlock(nn.Module):
    """BiAttentionBlock Module:

    First, multi-level visual features are concat; Then the concat visual
    feature and lang feature are fused by attention; Finally the newly visual
    feature are split into multi levels.

    Args:
        v_dim (int): The dimension of the visual features.
        l_dim (int): The dimension of the language feature.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): The number of attention heads.
        dropout (float, optional): The dropout probability. Defaults to 0.1.
        drop_path (float, optional): The drop path probability.
            Defaults to 0.0.
        init_values (float, optional):
            The initial value for the scaling parameter.
            Defaults to 1e-4.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 drop_path: float = .0,
                 init_values: float = 1e-4):
        super().__init__()

        # Pre-attention layer norms, one per modality.
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout)

        # Stochastic depth on the residual branch plus learnable per-channel
        # layer-scale factors (gamma_v / gamma_l) initialized to init_values.
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.gamma_v = nn.Parameter(
            init_values * torch.ones(v_dim), requires_grad=True)
        self.gamma_l = nn.Parameter(
            init_values * torch.ones(l_dim), requires_grad=True)

    def forward(self,
                vf0: Tensor,
                vf1: Tensor,
                vf2: Tensor,
                vf3: Tensor,
                vf4: Tensor,
                lang_feature: Tensor,
                attention_mask_l=None):
        """Fuse five levels of visual features with the language feature.

        Each level (bs, c, h, w) is flattened, all levels are concatenated
        along the token axis, fused with the language feature by attention,
        and the fused tokens are split back into per-level maps.

        Args:
            vf0 (Tensor): Visual features of shape (bs, c, h0, w0).
            vf1 (Tensor): Visual features of shape (bs, c, h1, w1).
            vf2 (Tensor): Visual features of shape (bs, c, h2, w2).
            vf3 (Tensor): Visual features of shape (bs, c, h3, w3).
            vf4 (Tensor): Visual features of shape (bs, c, h4, w4).
            lang_feature (Tensor): Language features.
            attention_mask_l (Tensor, optional): Attention mask for the
                language features. Defaults to None.

        Returns:
            tuple: The five fused per-level visual feature maps followed
            by the fused language feature.
        """
        visual_features = [vf0, vf1, vf2, vf3, vf4]
        size_per_level, visual_features_flatten = [], []
        for i, feat_per_level in enumerate(visual_features):
            bs, c, h, w = feat_per_level.shape
            size_per_level.append([h, w])
            feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w)
            visual_features_flatten.append(feat)
        visual_features_flatten = torch.cat(visual_features_flatten, dim=1)
        new_v, new_lang_feature = self.single_attention_call(
            visual_features_flatten,
            lang_feature,
            attention_mask_l=attention_mask_l)
        # (bs, N, C) -> (bs, C, N) so each level can be viewed as a map.
        new_v = new_v.transpose(1, 2).contiguous()

        start = 0
        # fvfs: fused visual features, one entry per level.
        fvfs = []
        for (h, w) in size_per_level:
            # NOTE(review): `bs` carries over from the last iteration of the
            # flatten loop above — assumes all levels share one batch size.
            new_v_per_level = new_v[:, :,
                                    start:start + h * w].view(bs, -1, h,
                                                              w).contiguous()
            fvfs.append(new_v_per_level)
            start += h * w

        return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature

    def single_attention_call(
        self,
        visual: Tensor,
        lang: Tensor,
        attention_mask_v: Optional[Tensor] = None,
        attention_mask_l: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Perform a single attention call between the visual and language
        inputs.

        Args:
            visual (Tensor): The visual input tensor.
            lang (Tensor): The language input tensor.
            attention_mask_v (Optional[Tensor]):
                An optional attention mask tensor for the visual input.
            attention_mask_l (Optional[Tensor]):
                An optional attention mask tensor for the language input.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the updated
            visual and language tensors after the attention call.
        """
        visual = self.layer_norm_v(visual)
        lang = self.layer_norm_l(lang)
        delta_v, delta_l = self.attn(
            visual,
            lang,
            attention_mask_v=attention_mask_v,
            attention_mask_l=attention_mask_l)
        # Residual update scaled by the learnable layer-scale factors,
        # with stochastic depth applied to the whole delta.
        visual = visual + self.drop_path(self.gamma_v * delta_v)
        lang = lang + self.drop_path(self.gamma_l * delta_l)
        return visual, lang
| |
|
| |
|
class SingleScaleBiAttentionBlock(BiAttentionBlock):
    """Single-scale variant of `BiAttentionBlock`.

    The only difference from `BiAttentionBlock` is the `forward` signature:
    this subclass fuses one already-flattened visual feature map with the
    language feature, whereas the parent accepts five separate multi-level
    visual feature maps.
    """

    def forward(self,
                visual_feature: Tensor,
                lang_feature: Tensor,
                attention_mask_v=None,
                attention_mask_l=None):
        """Fuse one flattened visual feature map with the language feature.

        Args:
            visual_feature (Tensor): Flattened visual features of shape
                (bs, patch_len, ch).
            lang_feature (Tensor): Language features of shape
                (bs, text_len, ch).
            attention_mask_v (Tensor, optional): Attention mask for the
                visual features. Defaults to None.
            attention_mask_l (Tensor, optional): Attention mask for the
                language features. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor]: The fused visual and language features.
        """
        fused_visual, fused_lang = self.single_attention_call(
            visual_feature,
            lang_feature,
            attention_mask_v=attention_mask_v,
            attention_mask_l=attention_mask_l)
        return fused_visual, fused_lang
| |
|
| |
|
class VLFuse(nn.Module):
    """Early Fusion Module.

    Wraps a :class:`BiAttentionBlock` and exchanges information between
    five levels of visual features and the language features.

    Args:
        v_dim (int): Dimension of visual features.
        l_dim (int): Dimension of language features.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability.
        drop_path (float): Drop path probability.
        use_checkpoint (bool): Whether to use PyTorch's checkpoint function.
    """

    def __init__(self,
                 v_dim: int = 256,
                 l_dim: int = 768,
                 embed_dim: int = 2048,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 drop_path: float = 0.0,
                 use_checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        # Layer-scale init is fixed to 1/6 here (not the BiAttentionBlock
        # default of 1e-4).
        self.b_attn = BiAttentionBlock(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            drop_path=drop_path,
            init_values=1.0 / 6.0)

    def forward(self, x: dict) -> dict:
        """Forward pass of the VLFuse module.

        Args:
            x (dict): Input dict with key 'visual' (five visual feature
                maps, unpacked into the fusion block) and key 'lang'
                (a dict providing 'hidden' and 'masks' entries).

        Returns:
            dict: Fused features in the same {'visual', 'lang'} layout.
        """
        visual_features = x['visual']
        language_dict_features = x['lang']

        if self.use_checkpoint:
            # Trade compute for memory: the fusion block's activations are
            # recomputed during the backward pass instead of being stored.
            vf0, vf1, vf2, vf3, vf4, language_features = checkpoint.checkpoint(
                self.b_attn, *visual_features,
                language_dict_features['hidden'],
                language_dict_features['masks'])
        else:
            vf0, vf1, vf2, vf3, vf4, language_features = self.b_attn(
                *visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])

        # NOTE: the incoming language dict is mutated in place before being
        # returned under the 'lang' key.
        language_dict_features['hidden'] = language_features
        fused_language_dict_features = language_dict_features

        features_dict = {
            'visual': [vf0, vf1, vf2, vf3, vf4],
            'lang': fused_language_dict_features
        }

        return features_dict
| |
|
| |
|
class BertEncoderLayer(BertPreTrainedModel):
    """A modified version of the `BertLayer` class from the
    `transformers.models.bert.modeling_bert` module.

    Args:
        config (:class:`~transformers.BertConfig`):
            The configuration object that
            contains various parameters for the model.
        clamp_min_for_underflow (bool, optional):
            Whether to clamp the minimum value of the hidden states
            to prevent underflow. Defaults to `False`.
        clamp_max_for_overflow (bool, optional):
            Whether to clamp the maximum value of the hidden states
            to prevent overflow. Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        self.config = config
        # Chunking parameters forwarded to apply_chunking_to_forward.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

        # Clamp-aware replacements for the stock HF sub-layers.
        self.attention = BertAttention(config, clamp_min_for_underflow,
                                       clamp_max_for_overflow)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self, inputs: Dict[str, Dict[str, torch.Tensor]]
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Applies the BertEncoderLayer to the input features.

        Only the 'lang' entry of *inputs* is transformed; the 'visual'
        entry is passed through unchanged.
        """
        language_dict_features = inputs['lang']
        hidden_states = language_dict_features['hidden']
        attention_mask = language_dict_features['masks']

        device = hidden_states.device
        input_shape = hidden_states.size()[:-1]
        # Expand the 2-D padding mask into the additive 4-D form BERT's
        # attention expects (helper inherited from BertPreTrainedModel).
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device)

        self_attention_outputs = self.attention(
            hidden_states,
            extended_attention_mask,
            None,
            output_attentions=False,
            past_key_value=None)
        attention_output = self_attention_outputs[0]
        # Any extra attention outputs (none here, given the flags above).
        outputs = self_attention_outputs[1:]
        # Run the feed-forward part in sequence chunks to bound memory.
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                                 self.chunk_size_feed_forward,
                                                 self.seq_len_dim,
                                                 attention_output)
        outputs = (layer_output, ) + outputs
        hidden_states = outputs[0]

        # NOTE: mutates the incoming language dict in place.
        language_dict_features['hidden'] = hidden_states

        features_dict = {
            'visual': inputs['visual'],
            'lang': language_dict_features
        }

        return features_dict

    def feed_forward_chunk(self, attention_output: Tensor) -> Tensor:
        """Applies the intermediate and output layers of the BertEncoderLayer
        to a chunk of the input sequence."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
| |
|
| |
|
| | |
| | |
class BertSelfAttention(nn.Module):
    """BERT self-attention layer from Huggingface transformers.

    Compared to the BertSelfAttention of Huggingface, only add the clamp.

    Args:
        config (:class:`~transformers.BertConfig`):
            The configuration object that
            contains various parameters for the model.
        clamp_min_for_underflow (bool, optional):
            Whether to clamp the minimum value of the hidden states
            to prevent underflow. Defaults to `False`.
        clamp_max_for_overflow (bool, optional):
            Whether to clamp the maximum value of the hidden states
            to prevent overflow. Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and \
                not hasattr(config, 'embedding_size'):
            raise ValueError(f'The hidden size ({config.hidden_size}) is '
                             'not a multiple of the number of attention '
                             f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # 'absolute' (default), 'relative_key' or 'relative_key_query'.
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')
        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per possible (query - key) distance.
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        # fp16 safety switches applied to the raw attention scores.
        self.clamp_min_for_underflow = clamp_min_for_underflow
        self.clamp_max_for_overflow = clamp_max_for_overflow

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        """Transpose the dimensions of `x`."""
        # (batch, seq, all_head) -> (batch, heads, seq, head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, ...]:
        """Perform a forward pass through the BERT self-attention layer."""

        mixed_query_layer = self.query(hidden_states)

        # If encoder_hidden_states is given, this layer acts as
        # cross-attention: keys/values come from the encoder and the
        # encoder's mask replaces attention_mask.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # Reuse cached cross-attention keys/values.
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with a cache: append new keys/values to it.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Expose the (possibly extended) cache to the caller.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if self.position_embedding_type == 'relative_key' or \
                self.position_embedding_type == 'relative_key_query':
            # Add learned relative-position scores based on the signed
            # distance between query and key positions.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)

            if self.position_embedding_type == 'relative_key':
                relative_position_scores = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == 'relative_key_query':
                relative_position_scores_query = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    'bhrd,lrd->bhlr', key_layer, positional_embedding)
                attention_scores = attention_scores + \
                    relative_position_scores_query + \
                    relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)

        # Clamp extreme scores to keep fp16 softmax numerically safe.
        if self.clamp_min_for_underflow:
            attention_scores = torch.clamp(
                attention_scores, min=-MAX_CLAMP_VALUE
            )
        if self.clamp_max_for_overflow:
            attention_scores = torch.clamp(
                attention_scores, max=MAX_CLAMP_VALUE
            )

        if attention_mask is not None:
            # Additive mask (large negatives at padded positions),
            # precomputed by get_extended_attention_mask upstream.
            attention_scores = attention_scores + attention_mask

        # Normalize the scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Dropout on whole attention columns, as in the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Optionally zero out selected heads.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if output_attentions else (context_layer, )

        if self.is_decoder:
            outputs = outputs + (past_key_value, )
        return outputs
| |
|
| |
|
class BertAttention(HFBertAttention):
    """BERT attention block built from self-attention plus output layers.

    Identical to Huggingface's ``BertAttention`` except that the
    self-attention sub-module is swapped for the local
    :class:`BertSelfAttention`, which can clamp attention scores for
    fp16 safety.

    Args:
        config (:class:`~transformers.BertConfig`): The configuration
            object that contains various parameters for the model.
        clamp_min_for_underflow (bool, optional): Whether to clamp the
            minimum value of the hidden states to prevent underflow.
            Defaults to `False`.
        clamp_max_for_overflow (bool, optional): Whether to clamp the
            maximum value of the hidden states to prevent overflow.
            Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        # Replace the stock self-attention with the clamp-aware variant.
        self.self = BertSelfAttention(config, clamp_min_for_underflow,
                                      clamp_max_for_overflow)
| |
|
| |
|
class BertIntermediate(HFBertIntermediate):
    """BERT intermediate (feed-forward expansion) layer.

    Modified from transformers.models.bert.modeling_bert.BertIntermediate:
    the hidden states are clamped after the dense projection and again
    after the activation, keeping fp16 values in a safe range.
    """

    def forward(self, hidden_states: Tensor) -> Tensor:
        expanded = clamp_values(self.dense(hidden_states))
        activated = clamp_values(self.intermediate_act_fn(expanded))
        return activated
| |
|
| |
|
class BertOutput(HFBertOutput):
    """BERT output (feed-forward projection + residual LayerNorm) layer.

    Modified from transformers.models.bert.modeling_bert.BertOutput:
    the hidden states are clamped after dropout and again after the
    residual LayerNorm, keeping fp16 values in a safe range.
    """

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        projected = self.dropout(self.dense(hidden_states))
        projected = clamp_values(projected)
        normalized = self.LayerNorm(projected + input_tensor)
        return clamp_values(normalized)
| |
|