| """ |
| PyTorch CharacterBERT model: this is a variant of BERT that uses the CharacterCNN module from ELMo instead of a |
| WordPiece embedding matrix. See: "CharacterBERT: Reconciling ELMo and BERT for Word-Level Open-Vocabulary |
| Representations From Characters" https://www.aclweb.org/anthology/2020.coling-main.609/ |
| """ |
|
|
| import math |
| import warnings |
| from dataclasses import dataclass |
| from typing import Callable, Optional, Tuple |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss, MSELoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.file_utils import ( |
| ModelOutput, |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| replace_return_docstrings, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPastAndCrossAttentions, |
| BaseModelOutputWithPoolingAndCrossAttentions, |
| CausalLMOutputWithCrossAttentions, |
| MaskedLMOutput, |
| MultipleChoiceModelOutput, |
| NextSentencePredictorOutput, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutput, |
| TokenClassifierOutput, |
| ) |
| from transformers.modeling_utils import ( |
| PreTrainedModel, |
| apply_chunking_to_forward, |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
| from transformers.utils import logging |
| from .configuration_character_bert import CharacterBertConfig |
| from .tokenization_character_bert import CharacterMapper |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CHECKPOINT_FOR_DOC = "helboukkouri/character-bert" |
| _CONFIG_FOR_DOC = "CharacterBertConfig" |
| _TOKENIZER_FOR_DOC = "CharacterBertTokenizer" |
|
|
| CHARACTER_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ |
| "helboukkouri/character-bert", |
| "helboukkouri/character-bert-medical", |
| |
| ] |
|
|
|
|
| |
| |
| class Highway(torch.nn.Module): |
| """ |
| A `Highway layer <https://arxiv.org/abs/1505.00387>`__ does a gated combination of a linear transformation and a |
| non-linear transformation of its input. :math:`y = g * x + (1 - g) * f(A(x))`, where :math:`A` is a linear |
| transformation, :math:`f` is an element-wise non-linearity, and :math:`g` is an element-wise gate, computed as |
| :math:`sigmoid(B(x))`. |
| |
| This module will apply a fixed number of highway layers to its input, returning the final result. |
| |
| # Parameters |
| |
| input_dim : `int`, required |
|     The dimensionality of :math:`x`. We assume the input has shape `(batch_size, ..., input_dim)`. |
| num_layers : `int`, optional (default=`1`) |
|     The number of highway layers to apply to the input. |
| activation : `Callable[[torch.Tensor], torch.Tensor]`, optional (default=`torch.nn.functional.relu`) |
|     The non-linearity to use in the highway layers. |
| """ |
|
|
| def __init__( |
| self, |
| input_dim: int, |
| num_layers: int = 1, |
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, |
| ) -> None: |
| super().__init__() |
| self._input_dim = input_dim |
| self._layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]) |
| self._activation = activation |
| for layer in self._layers: |
| # We should bias the highway layer to just carry its input forward by default. We do that by |
| # setting the bias on B(x) to be positive, so that the gate g = sigmoid(B(x)) starts out high |
| # and the input is mostly carried through. The bias on B(x) is the second half of the bias |
| # vector in each Linear layer. |
| layer.bias[input_dim:].data.fill_(1) |
|
|
| def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
| current_input = inputs |
| for layer in self._layers: |
| projected_input = layer(current_input) |
| linear_part = current_input |
| # the projection produces both halves at once: first the non-linear candidate f(A(x)), |
| # then the gate pre-activation B(x) |
| nonlinear_part, gate = projected_input.chunk(2, dim=-1) |
| nonlinear_part = self._activation(nonlinear_part) |
| gate = torch.sigmoid(gate) |
| current_input = gate * linear_part + (1 - gate) * nonlinear_part |
| return current_input |
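| # Illustrative usage sketch (comment only, not part of the model): a Highway stack preserves the |
| # input dimensionality, so the output can be checked shape-for-shape. Sizes below are arbitrary |
| # example values. |
| # |
| #     highway = Highway(input_dim=128, num_layers=2) |
| #     x = torch.randn(4, 10, 128) |
| #     assert highway(x).shape == x.shape  # gated mix of x and relu(A(x)) |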
|
|
|
|
| |
| |
| class CharacterCnn(torch.nn.Module): |
| """ |
| Computes context-insensitive token representations using multiple CNNs. This embedder takes character ids of |
| size (batch_size, sequence_length, 50) as input and returns (batch_size, sequence_length, hidden_size), where |
| hidden_size is typically 768. |
| """ |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.character_embeddings_dim = config.character_embeddings_dim |
| self.cnn_activation = config.cnn_activation |
| self.cnn_filters = config.cnn_filters |
| self.num_highway_layers = config.num_highway_layers |
| self.max_word_length = config.max_word_length |
| self.hidden_size = config.hidden_size |
| # 256 possible utf-8 bytes plus the CharacterMapper's special symbols (word boundaries, |
| # padding, mask, ...), everything shifted by one so that id 0 can serve as embedding padding |
| self.character_vocab_size = 263 |
| self._init_weights() |
|
|
| def get_output_dim(self): |
| return self.hidden_size |
|
|
| def _init_weights(self): |
| self._init_char_embedding() |
| self._init_cnn_weights() |
| self._init_highway() |
| self._init_projection() |
|
|
| def _init_char_embedding(self): |
| weights = torch.empty((self.character_vocab_size, self.character_embeddings_dim)) |
| nn.init.normal_(weights) |
| weights[0].fill_(0.0) |
| weights[CharacterMapper.padding_character + 1].fill_(0.0) |
| self._char_embedding_weights = torch.nn.Parameter(torch.FloatTensor(weights), requires_grad=True) |
|
|
| def _init_cnn_weights(self): |
| convolutions = [] |
| for i, (width, num) in enumerate(self.cnn_filters): |
| conv = torch.nn.Conv1d( |
| in_channels=self.character_embeddings_dim, out_channels=num, kernel_size=width, bias=True |
| ) |
| conv.weight.requires_grad = True |
| conv.bias.requires_grad = True |
| convolutions.append(conv) |
| self.add_module(f"char_conv_{i}", conv) |
| self._convolutions = convolutions |
|
|
| def _init_highway(self): |
| |
| n_filters = sum(f[1] for f in self.cnn_filters) |
| self._highways = Highway(n_filters, self.num_highway_layers, activation=nn.functional.relu) |
| for k in range(self.num_highway_layers): |
| |
| |
| self._highways._layers[k].weight.requires_grad = True |
| self._highways._layers[k].bias.requires_grad = True |
|
|
| def _init_projection(self): |
| n_filters = sum(f[1] for f in self.cnn_filters) |
| self._projection = torch.nn.Linear(n_filters, self.hidden_size, bias=True) |
| self._projection.weight.requires_grad = True |
| self._projection.bias.requires_grad = True |
|
|
| def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
| """ |
| Compute context-insensitive token embeddings from characters. |
|
| # Parameters |
|
| inputs : `torch.Tensor` |
|     Shape `(batch_size, sequence_length, 50)` of character ids representing the current batch. |
|
| # Returns |
|
| output : `torch.Tensor` |
|     Shape `(batch_size, sequence_length, embedding_dim)` tensor with context-insensitive token representations. |
| """ |
|
|
| |
| |
| # flatten the words and look up character embeddings: |
| # (batch_size * sequence_length, max_word_length, character_embeddings_dim) |
| character_embedding = torch.nn.functional.embedding( |
| inputs.view(-1, self.max_word_length), self._char_embedding_weights |
| ) |
|
|
| |
| if self.cnn_activation == "tanh": |
| activation = torch.tanh |
| elif self.cnn_activation == "relu": |
| activation = torch.nn.functional.relu |
| else: |
| raise ValueError(f"Unknown CNN activation: {self.cnn_activation}") |
|
|
| |
| # Conv1d expects channels first: (batch_size * sequence_length, character_embeddings_dim, max_word_length) |
| character_embedding = torch.transpose(character_embedding, 1, 2) |
| convs = [] |
| for i in range(len(self._convolutions)): |
| conv = getattr(self, "char_conv_{}".format(i)) |
| convolved = conv(character_embedding) |
| |
| # max-pool over character positions: (batch_size * sequence_length, num_filters) |
| convolved, _ = torch.max(convolved, dim=-1) |
| convolved = activation(convolved) |
| convs.append(convolved) |
|
|
| |
| # concatenate the outputs of all filter widths: (batch_size * sequence_length, n_filters) |
| token_embedding = torch.cat(convs, dim=-1) |
|
|
| |
| # apply the highway layers: shape is unchanged |
| token_embedding = self._highways(token_embedding) |
|
|
| |
| # project down to the model's hidden size: (batch_size * sequence_length, hidden_size) |
| token_embedding = self._projection(token_embedding) |
|
|
| |
| # reshape back to (batch_size, sequence_length, hidden_size) |
| batch_size, sequence_length, _ = inputs.size() |
| output = token_embedding.view(batch_size, sequence_length, -1) |
|
|
| return output |
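| # Illustrative shape walk-through (comment only; `character_cnn` and the sizes are example values): |
| # |
| #     char_ids = torch.randint(0, 263, (2, 8, 50))  # (batch, seq, max_word_length) |
| #     token_embeddings = character_cnn(char_ids)    # (2, 8, hidden_size) |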
|
|
|
|
| class CharacterBertEmbeddings(nn.Module): |
| """Construct the embeddings from word, position and token_type embeddings.""" |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.word_embeddings = CharacterCnn(config) |
| self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) |
| self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) |
|
|
| # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able |
| # to load any TensorFlow checkpoint file |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
| # position_ids (1, len position emb) is contiguous in memory and exported when serialized |
| self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) |
| self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") |
|
|
| def forward( |
| self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 |
| ): |
| if input_ids is not None: |
| # character-level input_ids carry an extra trailing dimension: (batch_size, sequence_length, max_word_length) |
| input_shape = input_ids[:, :, 0].size() |
| else: |
| input_shape = inputs_embeds.size()[:-1] |
|
|
| seq_length = input_shape[1] |
|
|
| if position_ids is None: |
| position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] |
|
|
| if token_type_ids is None: |
| token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.word_embeddings(input_ids) |
| token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
|
| embeddings = inputs_embeds + token_type_embeddings |
| if self.position_embedding_type == "absolute": |
| position_embeddings = self.position_embeddings(position_ids) |
| embeddings += position_embeddings |
| embeddings = self.LayerNorm(embeddings) |
| embeddings = self.dropout(embeddings) |
| return embeddings |
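| # Minimal sketch of the embedding pipeline (comment only; `config` stands for any CharacterBertConfig): |
| # |
| #     emb = CharacterBertEmbeddings(config) |
| #     char_ids = torch.randint(0, 263, (2, 8, config.max_word_length)) |
| #     out = emb(input_ids=char_ids)  # CharacterCNN + position + token-type, then LayerNorm/dropout |
| #     # out.shape == (2, 8, config.hidden_size) |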
|
|
|
|
| |
| class CharacterBertSelfAttention(nn.Module): |
| def __init__(self, config, position_embedding_type=None): |
| super().__init__() |
| if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): |
| raise ValueError( |
| f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " |
| f"heads ({config.num_attention_heads})" |
| ) |
|
|
| self.num_attention_heads = config.num_attention_heads |
| self.attention_head_size = int(config.hidden_size / config.num_attention_heads) |
| self.all_head_size = self.num_attention_heads * self.attention_head_size |
|
|
| self.query = nn.Linear(config.hidden_size, self.all_head_size) |
| self.key = nn.Linear(config.hidden_size, self.all_head_size) |
| self.value = nn.Linear(config.hidden_size, self.all_head_size) |
|
|
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) |
| self.position_embedding_type = position_embedding_type or getattr( |
| config, "position_embedding_type", "absolute" |
| ) |
| if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": |
| self.max_position_embeddings = config.max_position_embeddings |
| self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) |
|
|
| self.is_decoder = config.is_decoder |
|
|
| def transpose_for_scores(self, x): |
| new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) |
| x = x.view(*new_x_shape) |
| return x.permute(0, 2, 1, 3) |
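| # Shape sketch for transpose_for_scores (example sizes only): a (batch, seq, all_head_size) tensor |
| # is split into heads and reordered so attention can be computed per head: |
| # |
| #     (2, 8, 768) --view--> (2, 8, 12, 64) --permute--> (2, 12, 8, 64) |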
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| head_mask=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| past_key_value=None, |
| output_attentions=False, |
| ): |
| mixed_query_layer = self.query(hidden_states) |
|
|
| # If this is instantiated as a cross-attention module, the keys and values come from an encoder; |
| # the attention mask needs to be such that the encoder's padding tokens are not attended to. |
| is_cross_attention = encoder_hidden_states is not None |
|
|
| if is_cross_attention and past_key_value is not None: |
| # reuse k, v from the cross-attention cache |
| key_layer = past_key_value[0] |
| value_layer = past_key_value[1] |
| attention_mask = encoder_attention_mask |
| elif is_cross_attention: |
| key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) |
| value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) |
| attention_mask = encoder_attention_mask |
| elif past_key_value is not None: |
| key_layer = self.transpose_for_scores(self.key(hidden_states)) |
| value_layer = self.transpose_for_scores(self.value(hidden_states)) |
| key_layer = torch.cat([past_key_value[0], key_layer], dim=2) |
| value_layer = torch.cat([past_key_value[1], value_layer], dim=2) |
| else: |
| key_layer = self.transpose_for_scores(self.key(hidden_states)) |
| value_layer = self.transpose_for_scores(self.value(hidden_states)) |
|
|
| query_layer = self.transpose_for_scores(mixed_query_layer) |
|
|
| if self.is_decoder: |
| # if cross-attention: save Tuple(torch.Tensor, torch.Tensor) of all cross-attention key/value states; |
| # further calls to the cross-attention layer can then reuse them (first "if" case above). |
| # if uni-directional self-attention (decoder): save Tuple(torch.Tensor, torch.Tensor) of all previous |
| # decoder key/value states; further calls can concatenate them with the current projected |
| # key/value states (third "elif" case above). |
| # if encoder bi-directional self-attention: `past_key_value` is always `None`. |
| past_key_value = (key_layer, value_layer) |
|
|
| |
| # Take the dot product between "query" and "key" to get the raw attention scores. |
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
|
|
| if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": |
| seq_length = hidden_states.size()[1] |
| position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) |
| position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) |
| distance = position_ids_l - position_ids_r |
| positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) |
| positional_embedding = positional_embedding.to(dtype=query_layer.dtype) |
|
|
| if self.position_embedding_type == "relative_key": |
| relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
| attention_scores = attention_scores + relative_position_scores |
| elif self.position_embedding_type == "relative_key_query": |
| relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
| relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) |
| attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key |
|
|
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
| if attention_mask is not None: |
| # Apply the attention mask (precomputed for all layers in CharacterBertModel's forward() function) |
| attention_scores = attention_scores + attention_mask |
|
|
| |
| # Normalize the attention scores to probabilities. |
| attention_probs = nn.functional.softmax(attention_scores, dim=-1) |
|
|
| |
| |
| # This is actually dropping out entire tokens to attend to, which might |
| # seem a bit unusual, but is taken from the original Transformer paper. |
| attention_probs = self.dropout(attention_probs) |
|
|
| |
| # Mask heads if we want to |
| if head_mask is not None: |
| attention_probs = attention_probs * head_mask |
|
|
| context_layer = torch.matmul(attention_probs, value_layer) |
|
|
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() |
| new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) |
| context_layer = context_layer.view(*new_context_layer_shape) |
|
|
| outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) |
|
|
| if self.is_decoder: |
| outputs = outputs + (past_key_value,) |
| return outputs |
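| # Worked note on the relative-position branch above (example numbers only): with seq_length=4 and |
| # max_position_embeddings=512, `distance` is a (4, 4) matrix of l - r offsets in [-3, 3]; adding |
| # 512 - 1 = 511 shifts it into [508, 514], valid rows of the (2 * 512 - 1, head_size) embedding table. |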
|
|
|
|
| |
| class CharacterBertSelfOutput(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
| def forward(self, hidden_states, input_tensor): |
| hidden_states = self.dense(hidden_states) |
| hidden_states = self.dropout(hidden_states) |
| hidden_states = self.LayerNorm(hidden_states + input_tensor) |
| return hidden_states |
|
|
|
|
| |
| class CharacterBertAttention(nn.Module): |
| def __init__(self, config, position_embedding_type=None): |
| super().__init__() |
| self.self = CharacterBertSelfAttention(config, position_embedding_type=position_embedding_type) |
| self.output = CharacterBertSelfOutput(config) |
| self.pruned_heads = set() |
|
|
| def prune_heads(self, heads): |
| if len(heads) == 0: |
| return |
| heads, index = find_pruneable_heads_and_indices( |
| heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads |
| ) |
|
|
| |
| # Prune linear layers |
| self.self.query = prune_linear_layer(self.self.query, index) |
| self.self.key = prune_linear_layer(self.self.key, index) |
| self.self.value = prune_linear_layer(self.self.value, index) |
| self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) |
|
|
| |
| # Update hyper params and store pruned heads |
| self.self.num_attention_heads = self.self.num_attention_heads - len(heads) |
| self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads |
| self.pruned_heads = self.pruned_heads.union(heads) |
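| # Hypothetical usage sketch (comment only): pruning heads 0 and 2 of a 12-head block with |
| # 64-dim heads shrinks the q/k/v projections from 768 to 640 output features: |
| # |
| #     attention.prune_heads([0, 2]) |
| #     assert attention.self.num_attention_heads == 10 |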
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| head_mask=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| past_key_value=None, |
| output_attentions=False, |
| ): |
| self_outputs = self.self( |
| hidden_states, |
| attention_mask, |
| head_mask, |
| encoder_hidden_states, |
| encoder_attention_mask, |
| past_key_value, |
| output_attentions, |
| ) |
| attention_output = self.output(self_outputs[0], hidden_states) |
| outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them |
| return outputs |
|
|
|
|
| |
| class CharacterBertIntermediate(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.intermediate_size) |
| if isinstance(config.hidden_act, str): |
| self.intermediate_act_fn = ACT2FN[config.hidden_act] |
| else: |
| self.intermediate_act_fn = config.hidden_act |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.dense(hidden_states) |
| hidden_states = self.intermediate_act_fn(hidden_states) |
| return hidden_states |
|
|
|
|
| |
| class CharacterBertOutput(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.intermediate_size, config.hidden_size) |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
| def forward(self, hidden_states, input_tensor): |
| hidden_states = self.dense(hidden_states) |
| hidden_states = self.dropout(hidden_states) |
| hidden_states = self.LayerNorm(hidden_states + input_tensor) |
| return hidden_states |
|
|
|
|
| |
| class CharacterBertLayer(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.chunk_size_feed_forward = config.chunk_size_feed_forward |
| self.seq_len_dim = 1 |
| self.attention = CharacterBertAttention(config) |
| self.is_decoder = config.is_decoder |
| self.add_cross_attention = config.add_cross_attention |
| if self.add_cross_attention: |
| if not self.is_decoder: |
| raise ValueError(f"{self} should be used as a decoder model if cross attention is added") |
| self.crossattention = CharacterBertAttention(config, position_embedding_type="absolute") |
| self.intermediate = CharacterBertIntermediate(config) |
| self.output = CharacterBertOutput(config) |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| head_mask=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| past_key_value=None, |
| output_attentions=False, |
| ): |
| |
| # decoder uni-directional self-attention cached key/values tuple is at positions 1, 2 |
| self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None |
| self_attention_outputs = self.attention( |
| hidden_states, |
| attention_mask, |
| head_mask, |
| output_attentions=output_attentions, |
| past_key_value=self_attn_past_key_value, |
| ) |
| attention_output = self_attention_outputs[0] |
|
|
| |
| # if decoder, the last output is a tuple of the self-attention cache |
| if self.is_decoder: |
| outputs = self_attention_outputs[1:-1] |
| present_key_value = self_attention_outputs[-1] |
| else: |
| outputs = self_attention_outputs[1:] |
|
|
| cross_attn_present_key_value = None |
| if self.is_decoder and encoder_hidden_states is not None: |
| if not hasattr(self, "crossattention"): |
| raise ValueError( |
| f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" |
| ) |
|
|
| |
| # cross-attn cached key/values tuple is at positions 3, 4 of the past_key_value tuple |
| cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None |
| cross_attention_outputs = self.crossattention( |
| attention_output, |
| attention_mask, |
| head_mask, |
| encoder_hidden_states, |
| encoder_attention_mask, |
| cross_attn_past_key_value, |
| output_attentions, |
| ) |
| attention_output = cross_attention_outputs[0] |
| outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights |
|
|
| |
| # add cross-attn cache to positions 3, 4 of the present_key_value tuple |
| cross_attn_present_key_value = cross_attention_outputs[-1] |
| present_key_value = present_key_value + cross_attn_present_key_value |
|
|
| layer_output = apply_chunking_to_forward( |
| self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output |
| ) |
| outputs = (layer_output,) + outputs |
|
|
| |
| # if decoder, return the attention key/values as the last output |
| if self.is_decoder: |
| outputs = outputs + (present_key_value,) |
|
|
| return outputs |
|
|
| def feed_forward_chunk(self, attention_output): |
| intermediate_output = self.intermediate(attention_output) |
| layer_output = self.output(intermediate_output, attention_output) |
| return layer_output |
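| # Note on apply_chunking_to_forward (comment-only sketch): when config.chunk_size_feed_forward > 0, |
| # the feed-forward pass runs on slices of the sequence dimension to save memory, with an identical |
| # result to a single full pass. Roughly equivalent to: |
| # |
| #     chunks = attention_output.split(chunk_size, dim=1) |
| #     layer_output = torch.cat([self.feed_forward_chunk(c) for c in chunks], dim=1) |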
|
|
|
|
| |
| class CharacterBertEncoder(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.layer = nn.ModuleList([CharacterBertLayer(config) for _ in range(config.num_hidden_layers)]) |
| self.gradient_checkpointing = False |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| head_mask=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| past_key_values=None, |
| use_cache=None, |
| output_attentions=False, |
| output_hidden_states=False, |
| return_dict=True, |
| ): |
| all_hidden_states = () if output_hidden_states else None |
| all_self_attentions = () if output_attentions else None |
| all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None |
|
|
| next_decoder_cache = () if use_cache else None |
| for i, layer_module in enumerate(self.layer): |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
| layer_head_mask = head_mask[i] if head_mask is not None else None |
| past_key_value = past_key_values[i] if past_key_values is not None else None |
|
|
| if self.gradient_checkpointing and self.training: |
|
|
| if use_cache: |
| logger.warning( |
| "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| ) |
| use_cache = False |
|
|
| def create_custom_forward(module): |
| def custom_forward(*inputs): |
| return module(*inputs, past_key_value, output_attentions) |
|
|
| return custom_forward |
|
|
| layer_outputs = torch.utils.checkpoint.checkpoint( |
| create_custom_forward(layer_module), |
| hidden_states, |
| attention_mask, |
| layer_head_mask, |
| encoder_hidden_states, |
| encoder_attention_mask, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| attention_mask, |
| layer_head_mask, |
| encoder_hidden_states, |
| encoder_attention_mask, |
| past_key_value, |
| output_attentions, |
| ) |
|
|
| hidden_states = layer_outputs[0] |
| if use_cache: |
| next_decoder_cache += (layer_outputs[-1],) |
| if output_attentions: |
| all_self_attentions = all_self_attentions + (layer_outputs[1],) |
| if self.config.add_cross_attention: |
| all_cross_attentions = all_cross_attentions + (layer_outputs[2],) |
|
|
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| next_decoder_cache, |
| all_hidden_states, |
| all_self_attentions, |
| all_cross_attentions, |
| ] |
| if v is not None |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| past_key_values=next_decoder_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_self_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
|
|
|
|
| |
| class CharacterBertPooler(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
| self.activation = nn.Tanh() |
|
|
| def forward(self, hidden_states): |
| # We "pool" the model by simply taking the hidden state corresponding |
| # to the first token. |
| first_token_tensor = hidden_states[:, 0] |
| pooled_output = self.dense(first_token_tensor) |
| pooled_output = self.activation(pooled_output) |
| return pooled_output |
|
|
|
|
| |
| class CharacterBertPredictionHeadTransform(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
| if isinstance(config.hidden_act, str): |
| self.transform_act_fn = ACT2FN[config.hidden_act] |
| else: |
| self.transform_act_fn = config.hidden_act |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.dense(hidden_states) |
| hidden_states = self.transform_act_fn(hidden_states) |
| hidden_states = self.LayerNorm(hidden_states) |
| return hidden_states |
|
|
|
|
| class CharacterBertLMPredictionHead(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.transform = CharacterBertPredictionHeadTransform(config) |
|
|
| # Unlike BERT, the output projection cannot be tied to the input embeddings, since CharacterBERT |
| # builds its inputs from characters; predictions are made over a separate MLM output vocabulary |
| # of size config.mlm_vocab_size. |
| self.decoder = nn.Linear(config.hidden_size, config.mlm_vocab_size, bias=False) |
|
|
| self.bias = nn.Parameter(torch.zeros(config.mlm_vocab_size)) |
|
|
| |
| # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` |
| self.decoder.bias = self.bias |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.transform(hidden_states) |
| hidden_states = self.decoder(hidden_states) |
| return hidden_states |
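| # Shape sketch (comment only; `head` stands for an instance of this module): |
| # |
| #     scores = head(sequence_output)  # (batch, seq, hidden_size) -> (batch, seq, mlm_vocab_size) |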
|
|
|
|
| |
| class CharacterBertOnlyMLMHead(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.predictions = CharacterBertLMPredictionHead(config) |
|
|
| def forward(self, sequence_output): |
| prediction_scores = self.predictions(sequence_output) |
| return prediction_scores |
|
|
|
|
| |
| class CharacterBertOnlyNSPHead(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.seq_relationship = nn.Linear(config.hidden_size, 2) |
|
|
| def forward(self, pooled_output): |
| seq_relationship_score = self.seq_relationship(pooled_output) |
| return seq_relationship_score |
|
|
|
|
| |
| class CharacterBertPreTrainingHeads(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.predictions = CharacterBertLMPredictionHead(config) |
| self.seq_relationship = nn.Linear(config.hidden_size, 2) |
|
|
| def forward(self, sequence_output, pooled_output): |
| prediction_scores = self.predictions(sequence_output) |
| seq_relationship_score = self.seq_relationship(pooled_output) |
| return prediction_scores, seq_relationship_score |
|
|
|
|
| class CharacterBertPreTrainedModel(PreTrainedModel): |
| """ |
| An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
| models. |
| """ |
|
|
| config_class = CharacterBertConfig |
| load_tf_weights = None |
| base_model_prefix = "character_bert" |
| _keys_to_ignore_on_load_missing = [r"position_ids"] |
|
|
| def _init_weights(self, module): |
| """Initialize the weights""" |
| if isinstance(module, CharacterCnn): |
| # re-initialize the character embedding matrix, then re-zero the two padding rows |
| module._char_embedding_weights.data.normal_() |
| # index 0 is reserved for character-id padding |
| module._char_embedding_weights.data[0].fill_(0.0) |
| # row of the CharacterMapper padding character (shifted by one) |
| module._char_embedding_weights.data[CharacterMapper.padding_character + 1].fill_(0.0) |
| if isinstance(module, nn.Linear): |
| # Slightly different from the TF version which uses truncated_normal for initialization |
| # cf https://github.com/pytorch/pytorch/pull/5617 |
| module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
|
|
|
|
| @dataclass |
| class CharacterBertForPreTrainingOutput(ModelOutput): |
| """ |
| Output type of [`CharacterBertForPreTraining`]. |
| |
| Args: |
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
| Total loss as the sum of the masked language modeling loss and the next sequence prediction |
| (classification) loss. |
| prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.mlm_vocab_size)`): |
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
| seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): |
| Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation |
| before SoftMax). |
| hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
| Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of |
| shape `(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
| Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
| sequence_length)`. |
| |
| Attention weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| """ |
|
|
| loss: Optional[torch.FloatTensor] = None |
| prediction_logits: torch.FloatTensor = None |
| seq_relationship_logits: torch.FloatTensor = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
| CHARACTER_BERT_START_DOCSTRING = r""" |
| This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use |
| it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and |
| behavior. |
| |
| Parameters: |
| config ([`CharacterBertConfig`]): |
| Model configuration class with all the parameters of the model. Initializing with a config file does |
| not load the weights associated with the model, only the configuration. Check out the |
| [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
| CHARACTER_BERT_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `{0}`): |
| Indices of input sequence tokens (each token is represented by a vector of character ids). |
| |
| Indices can be obtained using [`CharacterBertTokenizer`]. See |
| [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for |
| details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| attention_mask (`torch.FloatTensor` of shape `{1}`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| token_type_ids (`torch.LongTensor` of shape `{1}`, *optional*): |
| Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: |
| |
| - 0 corresponds to a *sentence A* token, |
| - 1 corresponds to a *sentence B* token. |
| |
| [What are token type IDs?](../glossary#token-type-ids) |
| position_ids (`torch.LongTensor` of shape `{1}`, *optional*): |
| Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. |
| |
| [What are position IDs?](../glossary#position-ids) |
| head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): |
| Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: |
| |
| - 1 indicates the head is **not masked**, |
| - 0 indicates the head is **masked**. |
| |
| inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
| This is useful if you want more control over how to convert *input_ids* indices into associated vectors |
| than the model's internal embedding lookup matrix. |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. |
| """ |
|
|
|
|
| @add_start_docstrings( |
| "The bare CharacterBERT Model transformer outputting raw hidden-states without any specific head on top.", |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertModel(CharacterBertPreTrainedModel): |
| """ |
| |
| The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of |
| cross-attention is added between the self-attention layers, following the architecture described in [Attention is |
| all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, |
| Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. |
| |
| To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration |
| set to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` |
| argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an |
| input to the forward pass. |
| """ |
|
|
| def __init__(self, config, add_pooling_layer=True): |
| super().__init__(config) |
| self.config = config |
|
|
| self.embeddings = CharacterBertEmbeddings(config) |
| self.encoder = CharacterBertEncoder(config) |
|
|
| self.pooler = CharacterBertPooler(config) if add_pooling_layer else None |
|
|
| self.init_weights() |
|
|
| def get_input_embeddings(self): |
| return self.embeddings.word_embeddings |
|
|
| def set_input_embeddings(self, value): |
| self.embeddings.word_embeddings = value |
|
|
| def resize_token_embeddings(self, *args, **kwargs): |
| raise NotImplementedError("Cannot resize CharacterBERT's token embeddings.") |
|
|
| def _prune_heads(self, heads_to_prune): |
| """ |
| Prunes heads of the model. |
|
| heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base class `PreTrainedModel`. |
| """ |
| for layer, heads in heads_to_prune.items(): |
| self.encoder.layer[layer].attention.prune_heads(heads) |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=BaseModelOutputWithPoolingAndCrossAttentions, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| past_key_values=None, |
| use_cache=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Sequence of hidden-states at the output of the last layer of the encoder. Used in the |
| cross-attention if the model is configured as a decoder. |
| encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in |
| the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): |
| Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up |
| decoding. If `past_key_values` are used, the user can optionally input only the last |
| `decoder_input_ids` (those that don't have their past key value states given to this model) of shape |
| `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. |
| use_cache (`bool`, *optional*): |
| If set to `True`, `past_key_values` key value states are returned and can be used to speed up |
| decoding (see `past_key_values`). |
| """ |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| if self.config.is_decoder: |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| else: |
| use_cache = False |
|
|
| if input_ids is not None and inputs_embeds is not None: |
| raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
| elif input_ids is not None: |
| input_shape = input_ids.size()[:-1] |
| batch_size, seq_length = input_shape |
| elif inputs_embeds is not None: |
| input_shape = inputs_embeds.size()[:-1] |
| batch_size, seq_length = input_shape |
| else: |
| raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
| device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
| |
| # length of the cached decoder key/value states; 0 on the first forward pass |
| past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
|
|
| if attention_mask is None: |
| attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) |
| if token_type_ids is None: |
| token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) |
|
|
| |
| |
| # We can provide a self-attention mask of dimensions |
| # [batch_size, from_seq_length, to_seq_length] ourselves, in which case we just need to make it |
| # broadcastable to all heads. |
| extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) |
|
|
| |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention, |
| # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| if self.config.is_decoder and encoder_hidden_states is not None: |
| encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() |
| encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) |
| if encoder_attention_mask is None: |
| encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) |
| encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) |
| else: |
| encoder_extended_attention_mask = None |
|
|
| |
| |
| |
| |
| |
| # Prepare head mask if needed: 1.0 in head_mask means we keep the head. |
| # attention_probs has shape bsz x n_heads x N x N; |
| # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] |
| # and is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] |
| head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) |
|
|
| embedding_output = self.embeddings( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| token_type_ids=token_type_ids, |
| inputs_embeds=inputs_embeds, |
| past_key_values_length=past_key_values_length, |
| ) |
| encoder_outputs = self.encoder( |
| embedding_output, |
| attention_mask=extended_attention_mask, |
| head_mask=head_mask, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
| sequence_output = encoder_outputs[0] |
| pooled_output = self.pooler(sequence_output) if self.pooler is not None else None |
|
|
| if not return_dict: |
| return (sequence_output, pooled_output) + encoder_outputs[1:] |
|
|
| return BaseModelOutputWithPoolingAndCrossAttentions( |
| last_hidden_state=sequence_output, |
| pooler_output=pooled_output, |
| past_key_values=encoder_outputs.past_key_values, |
| hidden_states=encoder_outputs.hidden_states, |
| attentions=encoder_outputs.attentions, |
| cross_attentions=encoder_outputs.cross_attentions, |
| ) |
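| # Hypothetical end-to-end sketch (comment only; tensor sizes are examples, random ids are not |
| # meaningful text): |
| # |
| #     model = CharacterBertModel.from_pretrained("helboukkouri/character-bert") |
| #     char_ids = torch.randint(0, 263, (1, 9, 50))  # (batch, seq, max_word_length) |
| #     out = model(input_ids=char_ids) |
| #     out.last_hidden_state.shape                   # (1, 9, hidden_size) |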
|
|
|
|
| @add_start_docstrings( |
| """ |
| CharacterBert Model with two heads on top as done during pretraining: a `masked language modeling` head and a |
| `next sentence prediction (classification)` head. |
| """, |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForPreTraining(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.character_bert = CharacterBertModel(config) |
| self.cls = CharacterBertPreTrainingHeads(config) |
|
|
| self.init_weights() |
|
|
| def get_output_embeddings(self): |
| return self.cls.predictions.decoder |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.cls.predictions.decoder = new_embeddings |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @replace_return_docstrings(output_type=CharacterBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| next_sentence_label=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.mlm_vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| (masked); the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size - 1]`. |
| next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair |
| (see `input_ids` docstring). Indices should be in `[0, 1]`: |
| |
| - 0 indicates sequence B is a continuation of sequence A, |
| - 1 indicates sequence B is a random sequence. |
| |
| Returns: |
| |
| Example: |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForPreTraining |
| >>> import torch |
|
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForPreTraining.from_pretrained('helboukkouri/character-bert') |
|
| >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
| >>> outputs = model(**inputs) |
|
| >>> prediction_logits = outputs.prediction_logits |
| >>> seq_relationship_logits = outputs.seq_relationship_logits |
| ``` |
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output, pooled_output = outputs[:2] |
| prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) |
|
|
| total_loss = None |
| if labels is not None and next_sentence_label is not None: |
| loss_fct = CrossEntropyLoss() |
| masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1)) |
| next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) |
| total_loss = masked_lm_loss + next_sentence_loss |
|
|
| if not return_dict: |
| output = (prediction_scores, seq_relationship_score) + outputs[2:] |
| return ((total_loss,) + output) if total_loss is not None else output |
|
|
| return CharacterBertForPreTrainingOutput( |
| loss=total_loss, |
| prediction_logits=prediction_scores, |
| seq_relationship_logits=seq_relationship_score, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| @add_start_docstrings( |
| """CharacterBert Model with a `language modeling` head on top for CLM fine-tuning.""", |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertLMHeadModel(CharacterBertPreTrainedModel): |
|
|
| _keys_to_ignore_on_load_unexpected = [r"pooler"] |
| _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| if not config.is_decoder: |
| logger.warning("If you want to use `CharacterBertLMHeadModel` as a standalone, add `is_decoder=True`.") |
|
|
| self.character_bert = CharacterBertModel(config, add_pooling_layer=False) |
| self.cls = CharacterBertOnlyMLMHead(config) |
|
|
| self.init_weights() |
|
|
| def get_output_embeddings(self): |
| return self.cls.predictions.decoder |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.cls.predictions.decoder = new_embeddings |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| labels=None, |
| past_key_values=None, |
| use_cache=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Sequence of hidden-states at the output of the last layer of the encoder. Used in the |
| cross-attention if the model is configured as a decoder. |
| encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in |
| the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in |
| `[-100, 0, ..., config.mlm_vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` |
| are ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size - 1]`. |
| past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): |
| Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. |
| |
| If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` |
| (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` |
| instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. |
| use_cache (`bool`, *optional*): |
| If set to `True`, `past_key_values` key value states are returned and can be used to speed up |
| decoding (see `past_key_values`). |
| |
| Returns: |
| |
| Example: |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertLMHeadModel, CharacterBertConfig |
| >>> import torch |
|
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> config = CharacterBertConfig.from_pretrained("helboukkouri/character-bert") |
| >>> config.is_decoder = True |
| >>> model = CharacterBertLMHeadModel.from_pretrained('helboukkouri/character-bert', config=config) |
|
| >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") |
| >>> outputs = model(**inputs) |
|
| >>> prediction_logits = outputs.logits |
| ``` |
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| if labels is not None: |
| use_cache = False |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_attention_mask, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output = outputs[0] |
| prediction_scores = self.cls(sequence_output) |
|
|
| lm_loss = None |
| if labels is not None: |
| # we are doing next-token prediction; shift prediction scores and labels by one |
| shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() |
| labels = labels[:, 1:].contiguous() |
| loss_fct = CrossEntropyLoss() |
| lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (prediction_scores,) + outputs[2:] |
| return ((lm_loss,) + output) if lm_loss is not None else output |
|
|
| return CausalLMOutputWithCrossAttentions( |
| loss=lm_loss, |
| logits=prediction_scores, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| cross_attentions=outputs.cross_attentions, |
| ) |
|
|
| def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): |
| input_shape = input_ids.shape |
| # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly |
| if attention_mask is None: |
| attention_mask = input_ids.new_ones(input_shape) |
|
|
| |
| # cut decoder_input_ids if past is used |
| if past is not None: |
| input_ids = input_ids[:, -1:] |
|
|
| return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} |
|
|
| def _reorder_cache(self, past, beam_idx): |
| reordered_past = () |
| for layer_past in past: |
| reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) |
| return reordered_past |
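| # Sketch (comment only): during beam search, `beam_idx` selects which batch rows survive each step; |
| # every cached key/value tensor of shape (batch, heads, seq, head_size) is re-indexed along dim 0 |
| # so the cache follows the surviving beams: |
| # |
| #     reordered = past_state.index_select(0, beam_idx) |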
|
|
|
|
| @add_start_docstrings( |
| """CharacterBert Model with a `language modeling` head on top.""", CHARACTER_BERT_START_DOCSTRING |
| ) |
| class CharacterBertForMaskedLM(CharacterBertPreTrainedModel): |
|
|
| _keys_to_ignore_on_load_unexpected = [r"pooler"] |
| _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| if config.is_decoder: |
| logger.warning( |
| "If you want to use `CharacterBertForMaskedLM` make sure `config.is_decoder=False` for " |
| "bi-directional self-attention." |
| ) |
| self.character_bert = CharacterBertModel(config, add_pooling_layer=False) |
| self.cls = CharacterBertOnlyMLMHead(config) |
|
|
| self.init_weights() |
|
|
| def get_output_embeddings(self): |
| return self.cls.predictions.decoder |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.cls.predictions.decoder = new_embeddings |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=MaskedLMOutput, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.mlm_vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| (masked); the loss is only computed for the tokens with labels in `[0, ..., config.mlm_vocab_size - 1]`. |
| """ |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_attention_mask, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output = outputs[0] |
| prediction_scores = self.cls(sequence_output) |
|
|
| masked_lm_loss = None |
| if labels is not None: |
| loss_fct = CrossEntropyLoss() |
| masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.mlm_vocab_size), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (prediction_scores,) + outputs[2:] |
| return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output |
|
|
| return MaskedLMOutput( |
| loss=masked_lm_loss, |
| logits=prediction_scores, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
| def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): |
| input_shape = input_ids.shape |
| effective_batch_size = input_shape[0] |
|
|
| |
| # add a dummy PAD token at the end of the sequence |
| assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" |
| attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) |
| dummy_token = torch.full( |
| (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device |
| ) |
| input_ids = torch.cat([input_ids, dummy_token], dim=1) |
|
|
| return {"input_ids": input_ids, "attention_mask": attention_mask} |
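| # Sketch of the dummy-token trick above (comment only; example sizes): an MLM has no causal cache, |
| # so generation appends one PAD column per step, masked out via the appended zero attention column: |
| # |
| #     input_ids: (2, 7) -> (2, 8); attention_mask: (2, 7) -> (2, 8) with a final 0 |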
|
|
|
|
| @add_start_docstrings( |
| """CharacterBert Model with a `next sentence prediction (classification)` head on top.""", |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForNextSentencePrediction(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.character_bert = CharacterBertModel(config) |
| self.cls = CharacterBertOnlyNSPHead(config) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| **kwargs |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair |
| (see `input_ids` docstring). Indices should be in `[0, 1]`: |
| |
| - 0 indicates sequence B is a continuation of sequence A, |
| - 1 indicates sequence B is a random sequence. |
| |
| Returns: |
| |
| Example: |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForNextSentencePrediction |
| >>> import torch |
|
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForNextSentencePrediction.from_pretrained('helboukkouri/character-bert') |
|
| >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." |
| >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." |
| >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') |
|
| >>> outputs = model(**encoding, labels=torch.LongTensor([1])) |
| >>> logits = outputs.logits |
| >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random |
| ``` |
| """ |
|
|
| if "next_sentence_label" in kwargs: |
| warnings.warn( |
| "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", |
| FutureWarning, |
| ) |
| labels = kwargs.pop("next_sentence_label") |
|
|
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| pooled_output = outputs[1] |
|
|
| seq_relationship_scores = self.cls(pooled_output) |
|
|
| next_sentence_loss = None |
| if labels is not None: |
| loss_fct = CrossEntropyLoss() |
| next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (seq_relationship_scores,) + outputs[2:] |
| return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output |
|
|
| return NextSentencePredictorOutput( |
| loss=next_sentence_loss, |
| logits=seq_relationship_scores, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| @add_start_docstrings( |
| """ |
| CharacterBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the |
| pooled output) e.g. for GLUE tasks. |
| """, |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForSequenceClassification(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| self.num_labels = config.num_labels |
| self.character_bert = CharacterBertModel(config) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=SequenceClassifierOutput, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); |
| if `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). |
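| |
| Example (a minimal illustrative sketch; the `num_labels=2` head below is randomly initialized, so a fine-tuned checkpoint would be used in practice): |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForSequenceClassification |
| >>> import torch |
| |
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForSequenceClassification.from_pretrained('helboukkouri/character-bert', num_labels=2) |
| |
| >>> inputs = tokenizer("CharacterBERT builds word embeddings from characters.", return_tensors='pt') |
| >>> outputs = model(**inputs, labels=torch.tensor([1])) |
| >>> loss, logits = outputs.loss, outputs.logits  # logits: (batch_size, num_labels) |
| ``` |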
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| pooled_output = outputs[1] |
|
|
| pooled_output = self.dropout(pooled_output) |
| logits = self.classifier(pooled_output) |
|
|
| loss = None |
| if labels is not None: |
| if self.num_labels == 1: |
| # we are doing regression: a single label means an MSE loss on the raw logits |
| loss_fct = MSELoss() |
| loss = loss_fct(logits.view(-1), labels.view(-1)) |
| else: |
| loss_fct = CrossEntropyLoss() |
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (logits,) + outputs[2:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return SequenceClassifierOutput( |
| loss=loss, |
| logits=logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| @add_start_docstrings( |
| """ |
| CharacterBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output |
| and a softmax) e.g. for RocStories/SWAG tasks. |
| """, |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForMultipleChoice(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.character_bert = CharacterBertModel(config) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
| self.classifier = nn.Linear(config.hidden_size, 1) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=MultipleChoiceModelOutput, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`, where `num_choices` is the size of the second dimension of the input tensors (see |
| `input_ids` above). |
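| |
| Example (a minimal illustrative sketch of the `(batch_size, num_choices, ...)` layout expected by this head): |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForMultipleChoice |
| >>> import torch |
| |
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForMultipleChoice.from_pretrained('helboukkouri/character-bert') |
| |
| >>> prompt = "In Italy, pizza is presented unsliced." |
| >>> choice0 = "It is eaten with a fork and a knife." |
| >>> choice1 = "It is eaten while held in the hand." |
| |
| >>> # encode each (prompt, choice) pair, then add the leading batch dimension |
| >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True) |
| >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=torch.tensor([0])) |
| >>> loss, logits = outputs.loss, outputs.logits  # logits: (1, 2) |
| ``` |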
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] |
|
|
| # flatten the choices dimension: (batch_size, num_choices, seq_len, max_token_length) -> (batch_size * num_choices, seq_len, max_token_length) |
| input_ids = input_ids.view(-1, input_ids.size(-2), input_ids.size(-1)) if input_ids is not None else None |
| attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None |
| token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None |
| position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None |
| inputs_embeds = ( |
| inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) |
| if inputs_embeds is not None |
| else None |
| ) |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| pooled_output = outputs[1] |
|
|
| pooled_output = self.dropout(pooled_output) |
| logits = self.classifier(pooled_output) |
| reshaped_logits = logits.view(-1, num_choices) |
|
|
| loss = None |
| if labels is not None: |
| loss_fct = CrossEntropyLoss() |
| loss = loss_fct(reshaped_logits, labels) |
|
|
| if not return_dict: |
| output = (reshaped_logits,) + outputs[2:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return MultipleChoiceModelOutput( |
| loss=loss, |
| logits=reshaped_logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| @add_start_docstrings( |
| """ |
| CharacterBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) |
| e.g. for Named-Entity-Recognition (NER) tasks. |
| """, |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForTokenClassification(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| self.num_labels = config.num_labels |
|
|
| self.character_bert = CharacterBertModel(config) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=TokenClassifierOutput, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. |
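| |
| Example (a minimal illustrative sketch; `num_labels=5` is an arbitrary tag-set size for the randomly initialized head): |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForTokenClassification |
| >>> import torch |
| |
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForTokenClassification.from_pretrained('helboukkouri/character-bert', num_labels=5) |
| |
| >>> inputs = tokenizer("HuggingFace is based in NYC", return_tensors='pt') |
| >>> # `input_ids` is (batch_size, sequence_length, maximum_token_length); labels need one id per token |
| >>> labels = torch.zeros(inputs['input_ids'].shape[:2], dtype=torch.long) |
| >>> outputs = model(**inputs, labels=labels) |
| >>> loss, logits = outputs.loss, outputs.logits  # logits: (batch_size, sequence_length, num_labels) |
| ``` |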
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output = outputs[0] |
|
|
| sequence_output = self.dropout(sequence_output) |
| logits = self.classifier(sequence_output) |
|
|
| loss = None |
| if labels is not None: |
| loss_fct = CrossEntropyLoss() |
| # only keep the active parts of the loss, i.e. tokens where attention_mask == 1 |
| if attention_mask is not None: |
| active_loss = attention_mask.view(-1) == 1 |
| active_logits = logits.view(-1, self.num_labels) |
| active_labels = torch.where( |
| active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) |
| ) |
| loss = loss_fct(active_logits, active_labels) |
| else: |
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (logits,) + outputs[2:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return TokenClassifierOutput( |
| loss=loss, |
| logits=logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| @add_start_docstrings( |
| """ |
| CharacterBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a |
| linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). |
| """, |
| CHARACTER_BERT_START_DOCSTRING, |
| ) |
| class CharacterBertForQuestionAnswering(CharacterBertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| # span extraction always predicts exactly two values per token: start and end logits |
| config.num_labels = 2 |
| self.num_labels = config.num_labels |
|
|
| self.character_bert = CharacterBertModel(config) |
| self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_model_forward( |
| CHARACTER_BERT_INPUTS_DOCSTRING.format( |
| "(batch_size, sequence_length, maximum_token_length)", "(batch_size, sequence_length)" |
| ) |
| ) |
| @add_code_sample_docstrings( |
| processor_class=_TOKENIZER_FOR_DOC, |
| checkpoint=_CHECKPOINT_FOR_DOC, |
| output_type=QuestionAnsweringModelOutput, |
| config_class=_CONFIG_FOR_DOC, |
| ) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| start_positions=None, |
| end_positions=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| ): |
| r""" |
| start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for position (index) of the start of the labelled span for computing the token classification loss. |
| Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the |
| sequence are not taken into account for computing the loss. |
| end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for position (index) of the end of the labelled span for computing the token classification loss. |
| Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the |
| sequence are not taken into account for computing the loss. |
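| |
| Example (a minimal illustrative sketch of decoding a span from the start/end logits; an untuned checkpoint will give arbitrary spans): |
| |
| ```python |
| >>> from transformers import CharacterBertTokenizer, CharacterBertForQuestionAnswering |
| >>> import torch |
| |
| >>> tokenizer = CharacterBertTokenizer.from_pretrained('helboukkouri/character-bert') |
| >>> model = CharacterBertForQuestionAnswering.from_pretrained('helboukkouri/character-bert') |
| |
| >>> question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet" |
| >>> inputs = tokenizer(question, context, return_tensors='pt') |
| >>> outputs = model(**inputs) |
| |
| >>> # the most likely span runs from the argmax of the start logits to the argmax of the end logits |
| >>> start_index = outputs.start_logits.argmax(dim=-1).item() |
| >>> end_index = outputs.end_logits.argmax(dim=-1).item() |
| ``` |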
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| outputs = self.character_bert( |
| input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| sequence_output = outputs[0] |
|
|
| logits = self.qa_outputs(sequence_output) |
| start_logits, end_logits = logits.split(1, dim=-1) |
| start_logits = start_logits.squeeze(-1) |
| end_logits = end_logits.squeeze(-1) |
|
|
| total_loss = None |
| if start_positions is not None and end_positions is not None: |
| # if we are on multi-GPU, split adds an extra dimension; squeeze it away |
| if len(start_positions.size()) > 1: |
| start_positions = start_positions.squeeze(-1) |
| if len(end_positions.size()) > 1: |
| end_positions = end_positions.squeeze(-1) |
| # sometimes the start/end positions fall outside our model inputs; clamp them to |
| # `ignored_index` so the loss ignores these terms |
| ignored_index = start_logits.size(1) |
| start_positions.clamp_(0, ignored_index) |
| end_positions.clamp_(0, ignored_index) |
|
|
| loss_fct = CrossEntropyLoss(ignore_index=ignored_index) |
| start_loss = loss_fct(start_logits, start_positions) |
| end_loss = loss_fct(end_logits, end_positions) |
| total_loss = (start_loss + end_loss) / 2 |
|
|
| if not return_dict: |
| output = (start_logits, end_logits) + outputs[2:] |
| return ((total_loss,) + output) if total_loss is not None else output |
|
|
| return QuestionAnsweringModelOutput( |
| loss=total_loss, |
| start_logits=start_logits, |
| end_logits=end_logits, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |