| # coding=utf-8 |
| # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| #     http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """PyTorch ALBERT model. """ |
|
|
| import logging |
| import math |
| import os |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import CrossEntropyLoss, MSELoss |
|
|
| from .configuration_albert import AlbertConfig |
| from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer |
| from .modeling_utils import PreTrainedModel |
|
|
| from .file_utils import add_start_docstrings, add_start_docstrings_to_callable |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { |
| "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin", |
| "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin", |
| "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin", |
| "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin", |
| "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin", |
| "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin", |
| "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin", |
| "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin", |
| } |
|
|
|
|
| def load_tf_weights_in_albert(model, config, tf_checkpoint_path): |
| """ Load tf checkpoints in a pytorch model.""" |
| try: |
| import re |
| import numpy as np |
| import tensorflow as tf |
| except ImportError: |
| logger.error( |
| "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " |
| "https://www.tensorflow.org/install/ for installation instructions." |
| ) |
| raise |
| tf_path = os.path.abspath(tf_checkpoint_path) |
| logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) |
| |
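| # Load weights from the TF model |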
| init_vars = tf.train.list_variables(tf_path) |
| names = [] |
| arrays = [] |
| for name, shape in init_vars: |
| logger.info("Loading TF weight {} with shape {}".format(name, shape)) |
| array = tf.train.load_variable(tf_path, name) |
| names.append(name) |
| arrays.append(array) |
|
|
|
| for name, array in zip(names, arrays): |
| original_name = name |
|
|
| |
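| # If saved from the TF HUB module |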
| name = name.replace("module/", "") |
|
|
| |
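| # Renaming and simplifying |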
| name = name.replace("ffn_1", "ffn") |
| name = name.replace("bert/", "albert/") |
| name = name.replace("attention_1", "attention") |
| name = name.replace("transform/", "") |
| name = name.replace("LayerNorm_1", "full_layer_layer_norm") |
| name = name.replace("LayerNorm", "attention/LayerNorm") |
| name = name.replace("transformer/", "") |
|
|
| |
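| # The feed forward layer had an 'intermediate' step which has been abstracted away |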
| name = name.replace("intermediate/dense/", "") |
| name = name.replace("ffn/intermediate/output/dense/", "ffn_output/") |
|
|
| |
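| # ALBERT attention was split between self and output, which have been abstracted away |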
| name = name.replace("/output/", "/") |
| name = name.replace("/self/", "/") |
|
|
| |
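| # The pooler is a linear layer |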
| name = name.replace("pooler/dense", "pooler") |
|
|
| |
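| # The classifier was simplified to predictions from cls/predictions |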
| name = name.replace("cls/predictions", "predictions") |
| name = name.replace("predictions/attention", "predictions") |
|
|
| |
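| # Naming was changed to be more explicit |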
| name = name.replace("embeddings/attention", "embeddings") |
| name = name.replace("inner_group_", "albert_layers/") |
| name = name.replace("group_", "albert_layer_groups/") |
|
|
| |
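| # Classifier weights live at the top level in the TF checkpoint |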
| if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name): |
| name = "classifier/" + name |
|
|
| |
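| # No ALBERT model currently handles the next sentence prediction task |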
| if "seq_relationship" in name: |
| continue |
|
|
| name = name.split("/") |
|
|
| |
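| # Skip optimizer state and bookkeeping variables; they are not part of the model weights |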
| if ( |
| "adam_m" in name |
| or "adam_v" in name |
| or "AdamWeightDecayOptimizer" in name |
| or "AdamWeightDecayOptimizer_1" in name |
| or "global_step" in name |
| ): |
| logger.info("Skipping {}".format("/".join(name))) |
| continue |
|
|
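| # Walk down the attribute tree of the PyTorch model to find the parameter matching this TF variable |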
| pointer = model |
| for m_name in name: |
| if re.fullmatch(r"[A-Za-z]+_\d+", m_name): |
| scope_names = re.split(r"_(\d+)", m_name) |
| else: |
| scope_names = [m_name] |
|
|
| if scope_names[0] == "kernel" or scope_names[0] == "gamma": |
| pointer = getattr(pointer, "weight") |
| elif scope_names[0] == "output_bias" or scope_names[0] == "beta": |
| pointer = getattr(pointer, "bias") |
| elif scope_names[0] == "output_weights": |
| pointer = getattr(pointer, "weight") |
| elif scope_names[0] == "squad": |
| pointer = getattr(pointer, "classifier") |
| else: |
| try: |
| pointer = getattr(pointer, scope_names[0]) |
| except AttributeError: |
| logger.info("Skipping {}".format("/".join(name))) |
| continue |
| if len(scope_names) >= 2: |
| num = int(scope_names[1]) |
| pointer = pointer[num] |
|
|
| if m_name[-11:] == "_embeddings": |
| pointer = getattr(pointer, "weight") |
| elif m_name == "kernel": |
| array = np.transpose(array) |
| try: |
| assert pointer.shape == array.shape |
| except AssertionError as e: |
| e.args += (pointer.shape, array.shape) |
| raise |
| print("Initialize PyTorch weight {} from {}".format(name, original_name)) |
| pointer.data = torch.from_numpy(array) |
|
|
| return model |
|
|
|
|
| class AlbertEmbeddings(BertEmbeddings): |
| """ |
| Construct the embeddings from word, position and token_type embeddings. |
| """ |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0) |
| self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) |
| self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) |
| self.LayerNorm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) |
|
|
|
|
| class AlbertAttention(BertSelfAttention): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.output_attentions = config.output_attentions |
| self.num_attention_heads = config.num_attention_heads |
| self.hidden_size = config.hidden_size |
| self.attention_head_size = config.hidden_size // config.num_attention_heads |
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) |
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.pruned_heads = set() |
|
|
| def prune_heads(self, heads): |
| if len(heads) == 0: |
| return |
| mask = torch.ones(self.num_attention_heads, self.attention_head_size) |
| heads = set(heads) - self.pruned_heads |
| for head in heads: |
| |
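| # Compute how many already-pruned heads come before this head and shift the index accordingly |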
| head = head - sum(1 if h < head else 0 for h in self.pruned_heads) |
| mask[head] = 0 |
| mask = mask.view(-1).contiguous().eq(1) |
| index = torch.arange(len(mask))[mask].long() |
|
|
| |
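| # Prune linear layers |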
| self.query = prune_linear_layer(self.query, index) |
| self.key = prune_linear_layer(self.key, index) |
| self.value = prune_linear_layer(self.value, index) |
| self.dense = prune_linear_layer(self.dense, index, dim=1) |
|
|
| |
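| # Update hyper params and store pruned heads |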
| self.num_attention_heads = self.num_attention_heads - len(heads) |
| self.all_head_size = self.attention_head_size * self.num_attention_heads |
| self.pruned_heads = self.pruned_heads.union(heads) |
|
|
| def forward(self, hidden_states, attention_mask=None, head_mask=None): |
| mixed_query_layer = self.query(hidden_states) |
| mixed_key_layer = self.key(hidden_states) |
| mixed_value_layer = self.value(hidden_states) |
|
|
| query_layer = self.transpose_for_scores(mixed_query_layer) |
| key_layer = self.transpose_for_scores(mixed_key_layer) |
| value_layer = self.transpose_for_scores(mixed_value_layer) |
|
|
| |
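| # Take the dot product between "query" and "key" to get the raw attention scores |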
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
| if attention_mask is not None: |
| |
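| # Apply the attention mask (precomputed for all layers in AlbertModel's forward) |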
| attention_scores = attention_scores + attention_mask |
|
|
| |
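| # Normalize the attention scores to probabilities |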
| attention_probs = nn.Softmax(dim=-1)(attention_scores) |
|
|
| |
| |
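| # This is actually dropping out entire tokens to attend to, which might seem a bit unusual, |
| # but it is taken from the original Transformer paper |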
| attention_probs = self.dropout(attention_probs) |
|
|
| |
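| # Mask heads if we want to |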
| if head_mask is not None: |
| attention_probs = attention_probs * head_mask |
|
|
| context_layer = torch.matmul(attention_probs, value_layer) |
|
|
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() |
|
|
| |
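| # Merge the heads and apply the output projection in a single einsum: |
| # (batch, seq, heads, head_size) x (heads, head_size, hidden) -> (batch, seq, hidden) |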
| w = ( |
| self.dense.weight.t() |
| .view(self.num_attention_heads, self.attention_head_size, self.hidden_size) |
| .to(context_layer.dtype) |
| ) |
| b = self.dense.bias.to(context_layer.dtype) |
|
|
| projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b |
| projected_context_layer_dropout = self.dropout(projected_context_layer) |
| layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout) |
| return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,) |
|
|
|
|
| class AlbertLayer(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
|
|
| self.config = config |
| self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.attention = AlbertAttention(config) |
| self.ffn = nn.Linear(config.hidden_size, config.intermediate_size) |
| self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size) |
| self.activation = ACT2FN[config.hidden_act] |
|
|
| def forward(self, hidden_states, attention_mask=None, head_mask=None): |
| attention_output = self.attention(hidden_states, attention_mask, head_mask) |
| ffn_output = self.ffn(attention_output[0]) |
| ffn_output = self.activation(ffn_output) |
| ffn_output = self.ffn_output(ffn_output) |
| hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0]) |
|
|
| return (hidden_states,) + attention_output[1:] |
|
|
|
|
| class AlbertLayerGroup(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
|
|
| self.output_attentions = config.output_attentions |
| self.output_hidden_states = config.output_hidden_states |
| self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)]) |
|
|
| def forward(self, hidden_states, attention_mask=None, head_mask=None): |
| layer_hidden_states = () |
| layer_attentions = () |
|
|
| for layer_index, albert_layer in enumerate(self.albert_layers): |
| layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index]) |
| hidden_states = layer_output[0] |
|
|
| if self.output_attentions: |
| layer_attentions = layer_attentions + (layer_output[1],) |
|
|
| if self.output_hidden_states: |
| layer_hidden_states = layer_hidden_states + (hidden_states,) |
|
|
| outputs = (hidden_states,) |
| if self.output_hidden_states: |
| outputs = outputs + (layer_hidden_states,) |
| if self.output_attentions: |
| outputs = outputs + (layer_attentions,) |
| return outputs |
|
|
|
|
| class AlbertTransformer(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
|
|
| self.config = config |
| self.output_attentions = config.output_attentions |
| self.output_hidden_states = config.output_hidden_states |
| self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size) |
| self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]) |
|
|
| def forward(self, hidden_states, attention_mask=None, head_mask=None): |
| hidden_states = self.embedding_hidden_mapping_in(hidden_states) |
|
|
| all_attentions = () |
|
|
| if self.output_hidden_states: |
| all_hidden_states = (hidden_states,) |
|
|
| for i in range(self.config.num_hidden_layers): |
| |
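| # Number of layers in a hidden group |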
| layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) |
|
|
| |
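| # Index of the hidden group (e.g. with 12 layers and 2 groups, layers 0-5 run group 0 and layers 6-11 run group 1) |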
| group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) |
|
|
| layer_group_output = self.albert_layer_groups[group_idx]( |
| hidden_states, |
| attention_mask, |
| head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group], |
| ) |
| hidden_states = layer_group_output[0] |
|
|
| if self.output_attentions: |
| all_attentions = all_attentions + layer_group_output[-1] |
|
|
| if self.output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
| outputs = (hidden_states,) |
| if self.output_hidden_states: |
| outputs = outputs + (all_hidden_states,) |
| if self.output_attentions: |
| outputs = outputs + (all_attentions,) |
| return outputs |
|
|
|
|
| class AlbertPreTrainedModel(PreTrainedModel): |
| """ An abstract class to handle weights initialization and |
| a simple interface for downloading and loading pretrained models. |
| """ |
|
|
| config_class = AlbertConfig |
| pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP |
| base_model_prefix = "albert" |
|
|
| def _init_weights(self, module): |
| """ Initialize the weights. |
| """ |
| if isinstance(module, (nn.Linear, nn.Embedding)): |
| |
| |
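| # Slightly different from the TF version, which uses truncated_normal for initialization |
| # cf https://github.com/pytorch/pytorch/pull/5617 |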
| module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
| if isinstance(module, nn.Linear) and module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
|
|
|
|
| ALBERT_START_DOCSTRING = r""" |
| |
| This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. |
| Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general |
| usage and behavior. |
| |
| Args: |
| config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model. |
| Initializing with a config file does not load the weights associated with the model, only the configuration. |
| Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. |
| """ |
|
|
| ALBERT_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): |
| Indices of input sequence tokens in the vocabulary. |
| |
| Indices can be obtained using :class:`transformers.AlbertTokenizer`. |
| See :func:`transformers.PreTrainedTokenizer.encode` and |
| :func:`transformers.PreTrainedTokenizer.encode_plus` for details. |
| |
| `What are input IDs? <../glossary.html#input-ids>`__ |
| attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): |
| Mask to avoid performing attention on padding token indices. |
| Mask values selected in ``[0, 1]``: |
| ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. |
| |
| `What are attention masks? <../glossary.html#attention-mask>`__ |
| token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): |
| Segment token indices to indicate first and second portions of the inputs. |
| Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` |
| corresponds to a `sentence B` token |
| |
| `What are token type IDs? <../glossary.html#token-type-ids>`_ |
| position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): |
| Indices of positions of each input sequence tokens in the position embeddings. |
| Selected in the range ``[0, config.max_position_embeddings - 1]``. |
| |
| `What are position IDs? <../glossary.html#position-ids>`_ |
| head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): |
| Mask to nullify selected heads of the self-attention modules. |
| Mask values selected in ``[0, 1]``: |
| :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. |
| inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): |
| Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. |
| This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
| than the model's internal embedding lookup matrix. |
| """ |
|
|
|
|
| @add_start_docstrings( |
| "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.", |
| ALBERT_START_DOCSTRING, |
| ) |
| class AlbertModel(AlbertPreTrainedModel): |
|
|
| config_class = AlbertConfig |
| pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP |
| load_tf_weights = load_tf_weights_in_albert |
| base_model_prefix = "albert" |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.config = config |
| self.embeddings = AlbertEmbeddings(config) |
| self.encoder = AlbertTransformer(config) |
| self.pooler = nn.Linear(config.hidden_size, config.hidden_size) |
| self.pooler_activation = nn.Tanh() |
|
|
| self.init_weights() |
|
|
| def get_input_embeddings(self): |
| return self.embeddings.word_embeddings |
|
|
| def set_input_embeddings(self, value): |
| self.embeddings.word_embeddings = value |
|
|
| def _resize_token_embeddings(self, new_num_tokens): |
| old_embeddings = self.embeddings.word_embeddings |
| new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) |
| self.embeddings.word_embeddings = new_embeddings |
| return self.embeddings.word_embeddings |
|
|
| def _prune_heads(self, heads_to_prune): |
| """ Prunes heads of the model. |
| heads_to_prune: dict of {layer_num: list of heads to prune in this layer} |
| ALBERT has a different architecture in that its layers are shared across groups, each of which contains inner groups. |
| If an ALBERT model has 12 hidden layers and 2 hidden groups, each with two inner groups, there |
| is a total of 4 different layers. |
|
| These layers are flattened: the indices [0, 1] correspond to the two inner groups of the first hidden group, |
| while [2, 3] correspond to the two inner groups of the second hidden group. |
|
| Any layer with an index other than [0, 1, 2, 3] will result in an error. |
| See base class PreTrainedModel for more information about head pruning |
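|
| Example (a sketch for the 12-layer, 2-group, two-inner-group configuration above):: |
|
| # prune heads 0 and 1 of flattened layer 0 (group 0, inner group 0) |
| # and head 2 of flattened layer 3 (group 1, inner group 1) |
| model.prune_heads({0: [0, 1], 3: [2]}) |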
| """ |
| for layer, heads in heads_to_prune.items(): |
| group_idx = int(layer / self.config.inner_group_num) |
| inner_group_idx = int(layer - group_idx * self.config.inner_group_num) |
| self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads) |
|
|
| @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| ): |
| r""" |
| Return: |
| :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: |
| last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): |
| Sequence of hidden-states at the output of the last layer of the model. |
| pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): |
| Last layer hidden-state of the first token of the sequence (classification token) |
| further processed by a Linear layer and a Tanh activation function. The Linear |
| layer weights are trained from the next sentence prediction (classification) |
| objective during pre-training. |
| |
| This output is usually *not* a good summary of the semantic content of the input; you're often |
| better off averaging or pooling the sequence of hidden-states over the whole input sequence. |
| hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) |
| of shape :obj:`(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape |
| :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. |
| |
| Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| |
| Example:: |
| |
| from transformers import AlbertModel, AlbertTokenizer |
| import torch |
| |
| tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') |
| model = AlbertModel.from_pretrained('albert-base-v2') |
| input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 |
| outputs = model(input_ids) |
| last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple |
| |
| """ |
|
|
| if input_ids is not None and inputs_embeds is not None: |
| raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
| elif input_ids is not None: |
| input_shape = input_ids.size() |
| elif inputs_embeds is not None: |
| input_shape = inputs_embeds.size()[:-1] |
| else: |
| raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
| device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
| if attention_mask is None: |
| attention_mask = torch.ones(input_shape, device=device) |
| if token_type_ids is None: |
| token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) |
|
|
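| # Make the 2D mask broadcastable to [batch_size, num_heads, seq_length, seq_length] and convert it to an |
| # additive mask: positions to attend to become 0.0 and masked (padding) positions become -10000.0, |
| # which effectively removes them from the softmax |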
| extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
| extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) |
| extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 |
| if head_mask is not None: |
| if head_mask.dim() == 1: |
| head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) |
| head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) |
| elif head_mask.dim() == 2: |
| head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # we can specify head_mask for each layer |
| head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed, e.g. for fp16 compatibility |
| else: |
| head_mask = [None] * self.config.num_hidden_layers |
|
|
| embedding_output = self.embeddings( |
| input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds |
| ) |
| encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask) |
|
|
| sequence_output = encoder_outputs[0] |
|
|
| pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) |
|
|
| outputs = (sequence_output, pooled_output) + encoder_outputs[1:]  # add hidden_states and attentions if they are here |
| return outputs |
|
|
|
|
| class AlbertMLMHead(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
|
|
| self.LayerNorm = nn.LayerNorm(config.embedding_size) |
| self.bias = nn.Parameter(torch.zeros(config.vocab_size)) |
| self.dense = nn.Linear(config.hidden_size, config.embedding_size) |
| self.decoder = nn.Linear(config.embedding_size, config.vocab_size) |
| self.activation = ACT2FN[config.hidden_act] |
|
|
| |
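| # Link the two variables so that the bias is correctly resized with `resize_token_embeddings` |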
| self.decoder.bias = self.bias |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.dense(hidden_states) |
| hidden_states = self.activation(hidden_states) |
| hidden_states = self.LayerNorm(hidden_states) |
| hidden_states = self.decoder(hidden_states) |
|
|
| prediction_scores = hidden_states  # self.decoder already applies the tied bias; adding self.bias again would count it twice |
|
|
| return prediction_scores |
|
|
|
|
| @add_start_docstrings( |
| "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, |
| ) |
| class AlbertForMaskedLM(AlbertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.albert = AlbertModel(config) |
| self.predictions = AlbertMLMHead(config) |
|
|
| self.init_weights() |
| self.tie_weights() |
|
|
| def tie_weights(self): |
| self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings) |
|
|
| def get_output_embeddings(self): |
| return self.predictions.decoder |
|
|
| @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| masked_lm_labels=None, |
| ): |
| r""" |
| masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): |
| Labels for computing the masked language modeling loss. |
| Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) |
| Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with |
| labels in ``[0, ..., config.vocab_size]`` |
| |
| Returns: |
| :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: |
| loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: |
| Masked language modeling loss. |
| prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) |
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
| hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) |
| of shape :obj:`(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape |
| :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. |
| |
| Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| |
| Example:: |
| |
| from transformers import AlbertTokenizer, AlbertForMaskedLM |
| import torch |
| |
| tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') |
| model = AlbertForMaskedLM.from_pretrained('albert-base-v2') |
| input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 |
| outputs = model(input_ids, masked_lm_labels=input_ids) |
| loss, prediction_scores = outputs[:2] |
| |
| """ |
| outputs = self.albert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| ) |
| sequence_output = outputs[0] |
|
|
| prediction_scores = self.predictions(sequence_output) |
|
|
| outputs = (prediction_scores,) + outputs[2:] |
| if masked_lm_labels is not None: |
| loss_fct = CrossEntropyLoss() |
| masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) |
| outputs = (masked_lm_loss,) + outputs |
|
|
| return outputs |
|
|
|
|
| @add_start_docstrings( |
| """Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of |
| the pooled output) e.g. for GLUE tasks. """, |
| ALBERT_START_DOCSTRING, |
| ) |
| class AlbertForSequenceClassification(AlbertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| self.num_labels = config.num_labels |
|
|
| self.albert = AlbertModel(config) |
| self.dropout = nn.Dropout(config.classifier_dropout_prob) |
| self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| labels=None, |
| ): |
| r""" |
| labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): |
| Labels for computing the sequence classification/regression loss. |
| Indices should be in ``[0, ..., config.num_labels - 1]``. |
| If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), |
| If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). |
| |
| Returns: |
| :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: |
| loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: |
| Classification (or regression if config.num_labels==1) loss. |
| logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` |
| Classification (or regression if config.num_labels==1) scores (before SoftMax). |
| hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) |
| of shape :obj:`(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape |
| :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. |
| |
| Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| |
| Examples:: |
| |
| from transformers import AlbertTokenizer, AlbertForSequenceClassification |
| import torch |
| |
| tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') |
| model = AlbertForSequenceClassification.from_pretrained('albert-base-v2') |
| input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 |
| labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 |
| outputs = model(input_ids, labels=labels) |
| loss, logits = outputs[:2] |
| |
| """ |
|
|
| outputs = self.albert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| ) |
|
|
| pooled_output = outputs[1] |
|
|
| pooled_output = self.dropout(pooled_output) |
| logits = self.classifier(pooled_output) |
|
|
| outputs = (logits,) + outputs[2:] |
|
|
| if labels is not None: |
| if self.num_labels == 1: |
| |
| loss_fct = MSELoss() |
| loss = loss_fct(logits.view(-1), labels.view(-1)) |
| else: |
| loss_fct = CrossEntropyLoss() |
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
| outputs = (loss,) + outputs |
|
|
| return outputs |
|
|
|
|
| @add_start_docstrings( |
| """Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of |
| the hidden-states output to compute `span start logits` and `span end logits`). """, |
| ALBERT_START_DOCSTRING, |
| ) |
| class AlbertForQuestionAnswering(AlbertPreTrainedModel): |
| def __init__(self, config): |
| super().__init__(config) |
| self.num_labels = config.num_labels |
|
|
| self.albert = AlbertModel(config) |
| self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) |
|
|
| self.init_weights() |
|
|
| @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) |
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| token_type_ids=None, |
| position_ids=None, |
| head_mask=None, |
| inputs_embeds=None, |
| start_positions=None, |
| end_positions=None, |
| ): |
| r""" |
| start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): |
| Labels for position (index) of the start of the labelled span for computing the token classification loss. |
| Positions are clamped to the length of the sequence (`sequence_length`). |
| Position outside of the sequence are not taken into account for computing the loss. |
| end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): |
| Labels for position (index) of the end of the labelled span for computing the token classification loss. |
| Positions are clamped to the length of the sequence (`sequence_length`). |
| Position outside of the sequence are not taken into account for computing the loss. |
| |
| Returns: |
| :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: |
| loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: |
| Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. |
| start_scores ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` |
| Span-start scores (before SoftMax). |
| end_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` |
| Span-end scores (before SoftMax). |
| hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) |
| of shape :obj:`(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): |
| Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape |
| :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. |
| |
| Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| |
| Examples:: |
| |
| # The checkpoint albert-base-v2 is not fine-tuned for question answering. Please see the |
| # examples/run_squad.py example to see how to fine-tune a model to a question answering task. |
| |
| from transformers import AlbertTokenizer, AlbertForQuestionAnswering |
| import torch |
| |
| tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') |
| model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2') |
| question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" |
| input_dict = tokenizer.encode_plus(question, text, return_tensors='pt') |
| start_scores, end_scores = model(**input_dict) |
| |
| """ |
|
|
| outputs = self.albert( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| head_mask=head_mask, |
| inputs_embeds=inputs_embeds, |
| ) |
|
|
| sequence_output = outputs[0] |
|
|
| logits = self.qa_outputs(sequence_output) |
| start_logits, end_logits = logits.split(1, dim=-1) |
| start_logits = start_logits.squeeze(-1) |
| end_logits = end_logits.squeeze(-1) |
|
|
| outputs = (start_logits, end_logits,) + outputs[2:] |
| if start_positions is not None and end_positions is not None: |
| |
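| # If we are on multi-GPU, gathering can add an extra dimension to the positions; squeeze it away |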
| if len(start_positions.size()) > 1: |
| start_positions = start_positions.squeeze(-1) |
| if len(end_positions.size()) > 1: |
| end_positions = end_positions.squeeze(-1) |
| |
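| # Positions outside the sequence are clamped to ignored_index, which the loss below ignores |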
| ignored_index = start_logits.size(1) |
| start_positions.clamp_(0, ignored_index) |
| end_positions.clamp_(0, ignored_index) |
|
|
| loss_fct = CrossEntropyLoss(ignore_index=ignored_index) |
| start_loss = loss_fct(start_logits, start_positions) |
| end_loss = loss_fct(end_logits, end_positions) |
| total_loss = (start_loss + end_loss) / 2 |
| outputs = (total_loss,) + outputs |
|
|
| return outputs |
|
|