| """PyTorch MMBT model. """ |

import logging

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import add_start_docstrings


logger = logging.getLogger(__name__)


class ModalEmbeddings(nn.Module):
    """Generic modal embeddings that wrap a modality-specific encoder and reuse a transformer's embedding layers.
| """ |
|
|
    def __init__(self, config, encoder, embeddings):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        # Project the encoder output into the transformer's hidden size.
        token_embeddings = self.proj_embeddings(self.encoder(input_modal))
        seq_length = token_embeddings.size(1)

        # Optionally prepend a start token embedding (e.g. [CLS]) ...
        if start_token is not None:
            start_token_embeds = self.word_embeddings(start_token)
            seq_length += 1
            token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1)

        # ... and append an end token embedding (e.g. [SEP]) around the modal sequence.
        if end_token is not None:
            end_token_embeds = self.word_embeddings(end_token)
            seq_length += 1
            token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1)

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_modal.device)
            position_ids = position_ids.unsqueeze(0).expand(input_modal.size(0), seq_length)

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                (input_modal.size(0), seq_length), dtype=torch.long, device=input_modal.device
            )

        # Sum modal, position and token type embeddings, then normalize and apply dropout,
        # mirroring the transformer's own embedding pipeline.
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


MMBT_START_DOCSTRING = r"""    MMBT model was proposed in
    `Supervised Multimodal Bitransformers for Classifying Images and Text`_
    by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
    It's a supervised multimodal bitransformer model that fuses information from text and image encoders,
    and obtains state-of-the-art performance on various multimodal classification benchmark tasks.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Supervised Multimodal Bitransformers for Classifying Images and Text`:
        https://github.com/facebookresearch/mmbt

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.MMBTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
        transformer (:class:`~nn.Module`): A text transformer that is used by MMBT.
            It should have embeddings, encoder, and pooler attributes.
        encoder (:class:`~nn.Module`): Encoder for the second modality.
            It should take in a batch of modal inputs and return k, n dimension embeddings.
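
    A minimal sketch of a conforming ``encoder`` (illustrative only; any module with this output
    contract works)::

        import torch.nn as nn

        class ToyImageEncoder(nn.Module):
            # maps (batch_size, channels, height, width) -> (batch_size, k, modal_hidden_size)
            def __init__(self, k=3, modal_hidden_size=2048):
                super().__init__()
                self.proj = nn.Conv2d(3, modal_hidden_size, kernel_size=1)
                self.pool = nn.AdaptiveAvgPool2d((k, 1))

            def forward(self, x):
                x = self.proj(x)                      # (batch, modal_hidden_size, H, W)
                x = self.pool(x)                      # (batch, modal_hidden_size, k, 1)
                return x.squeeze(-1).transpose(1, 2)  # (batch, k, modal_hidden_size)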
| """ |
|
|
MMBT_INPUTS_DOCSTRING = r"""    Inputs:
        **input_modal**: ``torch.FloatTensor`` of shape ``(batch_size, ***)``:
            The other modality data. Its shape is whatever the encoder for that modality expects,
            e.g. ``(batch_size, channels, height, width)`` for an image encoder.
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The [CLS] token should not be added here, as it is appended to the end of the other modality embeddings.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **modal_start_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Optional start token to be prepended to the other modality embedding, most commonly [CLS] for classification tasks.
        **modal_end_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Optional end token to be appended to the other modality embedding, most commonly [SEP].
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate different portions of the inputs.
        **modal_token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
            Segment token indices to indicate different portions of the non-text modality.
            The embeddings from these tokens will be summed with the respective token embeddings for the non-text modality.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence token in the position embeddings.
        **modal_position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
            Indices of positions of each input sequence token in the position embeddings for the non-text modality.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert ``input_ids`` indices into associated vectors
            than the model's internal embedding lookup matrix.
        **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
            is configured as a decoder.
        **encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
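
    A hedged sketch of how these inputs are typically assembled, assuming a BERT-style tokenizer and an
    image tensor (the variable names are illustrative only)::

        # text without special tokens; [CLS]/[SEP] are supplied as modal start/end tokens instead
        input_ids = torch.tensor([tokenizer.encode("a photo of a cat", add_special_tokens=False)])
        modal_start_tokens = torch.tensor([tokenizer.cls_token_id])
        modal_end_tokens = torch.tensor([tokenizer.sep_token_id])
        outputs = model(
            image_tensor, input_ids, modal_start_tokens=modal_start_tokens, modal_end_tokens=modal_end_tokens
        )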
| """ |
|
|
|
|
@add_start_docstrings(
    "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTModel(nn.Module):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during BERT pretraining. This output is usually *not* a good summary
            of the semantic content of the input; you're often better off averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        # For example purposes. Not runnable.
        transformer = BertModel.from_pretrained('bert-base-uncased')
        encoder = ImageEncoder(args)
        mmbt = MMBTModel(config, transformer, encoder)
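        # A hedged continuation of the example above (input_modal and input_ids are assumed to exist):
        outputs = mmbt(input_modal, input_ids)
        last_hidden_state, pooled_output = outputs[:2]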
| """ |
|
|
    def __init__(self, config, transformer, encoder):
        super().__init__()
        self.config = config
        self.transformer = transformer
        self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Embed the non-text modality, optionally wrapped in start/end tokens (e.g. [CLS] ... [SEP]).
        modal_embeddings = self.modal_encoder(
            input_modal,
            start_token=modal_start_tokens,
            end_token=modal_end_tokens,
            position_ids=modal_position_ids,
            token_type_ids=modal_token_type_ids,
        )

        input_modal_shape = modal_embeddings.size()[:-1]

        # Text tokens default to segment id 1; the modal segment defaults to 0 inside ModalEmbeddings.
        if token_type_ids is None:
            token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)

        txt_embeddings = self.transformer.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # The modal sequence is prepended to the text sequence before running the transformer encoder.
        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        # A provided attention_mask only covers the text tokens; extend it with ones for the modal tokens.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
            )

        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
            )

        # A 3D self-attention mask of shape [batch_size, from_seq_length, to_seq_length] only needs to be
        # made broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]

        # A 2D padding mask of shape [batch_size, seq_length]:
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if attention_mask.dim() == 2:
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]

        # attention_mask is 1.0 for positions to attend to and 0.0 for masked positions; convert it to an
        # additive mask that is 0.0 for attended positions and -10000.0 for masked ones, which is added to
        # the raw attention scores before the softmax and effectively removes the masked positions.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Same treatment for the cross-attention mask used when the model is configured as a decoder.
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare the head mask if needed: 1.0 in head_mask indicates the head is kept.
        # The input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] and is converted to
        # shape [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # head_mask specified per layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.transformer.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        # add hidden_states and attentions if they are present
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

    def get_input_embeddings(self):
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.transformer.embeddings.word_embeddings = value


@add_start_docstrings(
    """MMBT Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output)""",
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTForClassification(nn.Module):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss);
            if ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if ``config.num_labels == 1``) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if ``config.num_labels == 1``) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        # For example purposes. Not runnable.
        transformer = BertModel.from_pretrained('bert-base-uncased')
        encoder = ImageEncoder(args)
        model = MMBTForClassification(config, transformer, encoder)
        outputs = model(input_modal, input_ids, labels=labels)
        loss, logits = outputs[:2]
    """

    def __init__(self, config, transformer, encoder):
        super().__init__()
        self.num_labels = config.num_labels

        self.mmbt = MMBTModel(config, transformer, encoder)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.mmbt(
            input_modal=input_modal,
            input_ids=input_ids,
            modal_start_tokens=modal_start_tokens,
            modal_end_tokens=modal_end_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            modal_token_type_ids=modal_token_type_ids,
            position_ids=position_ids,
            modal_position_ids=modal_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are present

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)