| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """PyTorch PONET model.""" |
|
|
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPooling, |
| SequenceClassifierOutput, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer |
| from transformers.utils import ( |
| ModelOutput, |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| replace_return_docstrings, |
| ) |
| from .configuration_ponet import PoNetConfig |
|
|
|
|
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Checkpoint and config names handed to the docstring decorators further down.
_CHECKPOINT_FOR_DOC = "ponet-base"
_CONFIG_FOR_DOC = "PoNetConfig"

# Identifiers of the known pretrained PoNet checkpoints.
PONET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "chtan/ponet-base-uncased",
]

# Default token ids that `get_segment_index` and `get_token_type_mask` treat as
# segment delimiters.
# NOTE(review): 101/102 look like the [CLS]/[SEP] ids of the BERT uncased
# vocabulary -- confirm against the tokenizer actually used.
CLS_ID = 101
EOS_ID = 102
|
|
|
|
def segment_max(src, index, dim=1):
    """Max-pool `src` inside each segment and hand every position its segment's maximum.

    Args:
        src: Values to pool; the last axis is the feature axis.
        index: Integer segment id of every position (one axis fewer than `src`).
            The ids are used as scatter targets along `dim`, so they must be
            valid indices into that axis of `src`.
        dim: Axis along which segments are laid out.

    Returns:
        A tensor shaped like `src` (same dtype) where each position holds the
        feature-wise maximum over all positions that share its segment id.
    """
    # Step 1: reduce every segment into the slot addressed by its id.
    scatter_index = index.unsqueeze(-1).expand_as(src)
    pooled = torch.zeros_like(src).scatter_reduce(
        dim, scatter_index, src, reduce="amax", include_self=False
    )
    # Step 2: read the per-segment maxima back out for every original position.
    gather_index = index.unsqueeze(-1).expand(*index.shape[:2], pooled.size(-1))
    broadcast = torch.gather(pooled, dim, gather_index)
    return broadcast.to(dtype=src.dtype)
|
|
|
|
def get_segment_index(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
    """Assign a running segment id to every token of `input_ids`.

    A delimiter token (`cls_id` or `eos_id`) bumps the counter at its own
    position and again at the position right after it, so each delimiter ends
    up alone in its own segment. Rows whose first token is `cls_id` are shifted
    down by one so that their numbering starts at 0.

    Args:
        input_ids: 2-D tensor of token ids (rows are sequences).
        cls_id: Id of the leading delimiter token.
        eos_id: Id of the separator token.

    Returns:
        A long tensor shaped like `input_ids` holding the segment ids.
    """
    is_delimiter = (input_ids == cls_id).to(dtype=torch.long) + (input_ids == eos_id).to(dtype=torch.long)
    # The same marks moved one step to the right (first column padded with 0).
    leading_zero = torch.zeros_like(is_delimiter[:, 0:1])
    after_delimiter = torch.cat([leading_zero, is_delimiter[:, :-1]], dim=1)

    running = (is_delimiter + after_delimiter).cumsum(dim=1)
    starts_with_cls = input_ids[:, :1] == cls_id
    return torch.where(starts_with_cls, running - 1, running)
|
|
|
|
def get_token_type_mask(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
    """Return a boolean mask that is True wherever `input_ids` equals `cls_id` or `eos_id`."""
    is_cls = input_ids == cls_id
    is_eos = input_ids == eos_id
    return is_cls | is_eos
|
|
|
|
def get_win_max(hidden_states, kernel_size=3):
    """Sliding-window max-pool along the middle axis of a 3-D tensor.

    Args:
        hidden_states: Tensor whose axis 1 is pooled over and whose axis 2 is
            treated as channels.
        kernel_size: Window width; stride is 1 and the padding is
            `kernel_size // 2` on each side.

    Returns:
        The pooled tensor, with axes in the same order as the input.
    """
    # Pooling runs over the last axis, so move the pooled axis there and back.
    channels_first = hidden_states.permute(0, 2, 1)
    pooled = nn.functional.max_pool1d(channels_first, kernel_size, stride=1, padding=kernel_size // 2)
    return pooled.permute(0, 2, 1)
|
|
|
|
| |
class PoNetEmbeddings(nn.Module):
    """Sum word, token-type and (optionally) absolute position embeddings, then normalise and drop out."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # `position_ids` is a (1, max_position_embeddings) row 0..max-1.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        # All-zero token types, same shape; kept out of the state dict.
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed a batch given either `input_ids` or precomputed `inputs_embeds`.

        Missing `position_ids` are sliced from the registered buffer starting at
        `past_key_values_length`; missing `token_type_ids` default to zeros.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            first = past_key_values_length
            position_ids = self.position_ids[:, first : seq_length + first]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Reuse the registered zero buffer, broadcast over the batch.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
|
|
|
|
class PoNetSelfAttention(nn.Module):
    """Pooling-based token mixing that combines three signals per token.

    * a global vector: attention of the averaged query over all keys,
    * a segment vector: max over the tokens sharing a segment id,
    * a local vector: sliding-window max over neighbouring tokens.

    They are merged as `(global + segment) * gate + local`.
    """

    def __init__(self, config):
        super().__init__()
        indivisible = config.hidden_size % config.num_attention_heads != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        # When True, delimiter tokens are excluded from the local window and
        # get zeroed segment/local features (see `forward`).
        self.clsgsepg = getattr(config, "clsgsepg", True)

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.dense_local = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_segment = nn.Linear(config.hidden_size, config.hidden_size)

        self.dense_q = nn.Linear(config.hidden_size, self.all_head_size)
        self.dense_k = nn.Linear(config.hidden_size, self.all_head_size)
        self.dense_o = nn.Linear(config.hidden_size, self.all_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into (heads, head_size) and move heads to axis 1."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        return x.view(x.size()[:-1] + head_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        segment_index: torch.LongTensor,
        token_type_mask: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Mix tokens and return `(context,)` or `(context, attention_probs)`.

        `attention_mask`, when given, is the additive mask whose entries below
        -1 mark padded positions; it is squeezed on axis 1 before use.
        """
        masked = attention_mask is not None

        queries = self.transpose_for_scores(self.dense_q(hidden_states))
        keys = self.transpose_for_scores(self.dense_k(hidden_states))
        values = keys  # keys double as values
        gate = self.transpose_for_scores(self.dense_o(hidden_states))

        # ---- global branch: one pooled query per head attends over all keys ----
        if masked:
            pad_mask = attention_mask.squeeze(1).unsqueeze(-1) < -1
            queries.masked_fill_(pad_mask, 0.0)
            # Average only over the positions that are not padded.
            n_valid = torch.ones_like(pad_mask).to(dtype=queries.dtype).masked_fill(pad_mask, 0.0).sum(dim=-2)
            q = queries.sum(dim=-2) / n_valid
        else:
            q = queries.mean(dim=-2)

        att = torch.einsum("bdh,bdlh -> bdl", q, keys) / math.sqrt(queries.shape[-1])
        if masked:
            att = att + attention_mask.squeeze(1)
        att_prob = att.softmax(dim=-1)
        v = torch.einsum("bdlh,bdl->bdh", values, att_prob)

        # ---- segment and local branches: max-pooling over projected states ----
        segment_feats = self.dense_segment(hidden_states)
        local_feats = self.dense_local(hidden_states)
        if masked:
            token_pad = pad_mask.squeeze(1)
            # A large negative value keeps padded tokens from winning any max.
            local_feats.masked_fill_(token_pad, -10000)
            segment_feats.masked_fill_(token_pad, -10000)

        if self.clsgsepg:
            delimiter = token_type_mask.unsqueeze(dim=-1)
            # Delimiters are hidden from their neighbours' windows ...
            local_feats = get_win_max(local_feats.masked_fill(delimiter, -10000))
            segment_feats = segment_max(segment_feats, index=segment_index)
            # ... and receive no segment/local contribution themselves.
            segment_feats.masked_fill_(delimiter, 0.0)
            local_feats.masked_fill_(delimiter, 0.0)
        else:
            local_feats = get_win_max(local_feats)
            segment_feats = segment_max(segment_feats, index=segment_index)

        local_feats = self.transpose_for_scores(local_feats)
        segment_feats = self.transpose_for_scores(segment_feats)

        # ---- fuse and merge the heads back into one feature axis ----
        context_layer = (v.unsqueeze(dim=-2) + segment_feats) * gate + local_feats
        context_layer = context_layer.permute(0, 2, 1, 3).reshape(*hidden_states.shape[:2], -1)

        if masked:
            context_layer.masked_fill_(pad_mask.squeeze(1), 0.0)

        if output_attentions:
            return (context_layer, att_prob)
        return (context_layer,)
|
|
|
|
| |
class PoNetSelfOutput(nn.Module):
    """Project the token-mixing output, apply dropout, add the residual and layer-normalise."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
class PoNetAttention(nn.Module):
    """Token-mixing sub-layer: `PoNetSelfAttention` followed by `PoNetSelfOutput`."""

    # Every projection inside `PoNetSelfAttention` whose output is later split
    # into heads; all of them must lose the same rows when heads are pruned.
    _HEAD_PROJECTIONS = ("dense_local", "dense_segment", "dense_q", "dense_k", "dense_o")

    def __init__(self, config):
        super().__init__()
        self.self = PoNetSelfAttention(config)
        self.output = PoNetSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from this layer.

        Args:
            heads: Collection of head indices to prune. An empty collection is a no-op.

        The previous implementation pruned `query`/`key`/`value` attributes that
        `PoNetSelfAttention` does not define, so it always raised
        `AttributeError`. The five projections the module really owns are
        pruned instead: their outputs are all reshaped with
        `transpose_for_scores` and combined element-wise, so they have to keep
        identical head layouts.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Drop the pruned heads' output rows from every head-shaped projection.
        for name in self._HEAD_PROJECTIONS:
            setattr(self.self, name, prune_linear_layer(getattr(self.self, name), index))
        # The output projection consumes the merged heads, so prune its input columns.
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        segment_index: torch.LongTensor,
        token_type_mask: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run token mixing plus the residual output block.

        Returns:
            `(attention_output,)`, followed by whatever extra entries the inner
            `PoNetSelfAttention` returned (the attention probabilities when
            `output_attentions` is set).
        """
        self_outputs = self.self(
            hidden_states,
            segment_index,
            token_type_mask,
            attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output,) + self_outputs[1:]
|
|
|
|
| |
class PoNetIntermediate(nn.Module):
    """First half of the feed-forward block: expand to `intermediate_size` and apply the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `act(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
| |
class PoNetOutput(nn.Module):
    """Second half of the feed-forward block: project back, dropout, residual add, layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
|
|
|
class PoNetLayer(nn.Module):
    """One encoder layer: token mixing (`PoNetAttention`) followed by a chunkable feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = PoNetAttention(config)

        # The layer forces encoder mode; note that this writes to the shared config object.
        config.is_decoder = False
        self.is_decoder = config.is_decoder

        self.intermediate = PoNetIntermediate(config)
        self.output = PoNetOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        segment_index: torch.LongTensor,
        token_type_mask: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Return `(layer_output,)` plus any extra entries produced by the attention sub-layer."""
        mixed = self.attention(
            hidden_states,
            segment_index,
            token_type_mask,
            attention_mask,
            output_attentions=output_attentions,
        )
        attention_output, extras = mixed[0], mixed[1:]

        # The feed-forward block is applied in chunks along the sequence axis.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output,) + extras

    def feed_forward_chunk(self, attention_output):
        """Feed-forward block applied to one chunk: intermediate expansion, then the residual output."""
        return self.output(self.intermediate(attention_output), attention_output)
|
|
|
|
class PoNetEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `PoNetLayer` modules applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([PoNetLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        segment_index: torch.LongTensor,
        token_type_mask: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run every layer in order.

        Returns:
            A `BaseModelOutput`, or (when `return_dict` is falsy) a tuple of the
            non-None values among last hidden state, all hidden states and all
            attention probabilities.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        for block in self.layer:
            if output_hidden_states:
                # Record the input of each layer; the final output is added after the loop.
                collected_states += (hidden_states,)

            use_checkpoint = self.gradient_checkpointing and self.training
            if use_checkpoint:

                def create_custom_forward(module):
                    # Bind `output_attentions`, which is not passed through `checkpoint`.
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                block_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    segment_index,
                    token_type_mask,
                    attention_mask,
                )
            else:
                block_outputs = block(
                    hidden_states,
                    segment_index,
                    token_type_mask,
                    attention_mask,
                    output_attentions,
                )

            hidden_states = block_outputs[0]
            if output_attentions:
                collected_attentions += (block_outputs[1],)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        candidates = (hidden_states, collected_states, collected_attentions)
        return tuple(item for item in candidates if item is not None)
|
|
|
|
| |
class PoNetPooler(nn.Module):
    """Summarise a sequence by passing its first token through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # Only the hidden state at position 0 of each sequence is used.
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))
|
|
|
|
| |
class PoNetPredictionHeadTransform(nn.Module):
    """Dense -> activation -> layer norm, applied before the vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either a key into ACT2FN or already a callable.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(act(dense(hidden_states)))`."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
| |
class PoNetLMPredictionHead(nn.Module):
    """Masked-LM head: transform the hidden states, then project them to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = PoNetPredictionHeadTransform(config)

        # The projection is created without a bias of its own ...
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # ... and a standalone bias parameter is attached afterwards, so
        # `self.bias` and `self.decoder.bias` refer to the same tensor.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary logits for every position of `hidden_states`."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
| |
class PoNetOnlyMLMHead(nn.Module):
    """Wrapper exposing a single `PoNetLMPredictionHead` under the attribute `predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = PoNetLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return the masked-LM logits for `sequence_output`."""
        return self.predictions(sequence_output)
|
|
|
|
| |
class PoNetOnlyNSPHead(nn.Module):
    """Two-way linear classifier applied to the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the 2-class logits for `pooled_output`."""
        return self.seq_relationship(pooled_output)
|
|
|
|
class PoNetOnlySSOHead(nn.Module):
    """Three-way linear classifier applied to the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 3)

    def forward(self, pooled_output):
        """Return the 3-class logits for `pooled_output`."""
        return self.seq_relationship(pooled_output)
|
|
|
|
class PoNetPreTrainingHeads(nn.Module):
    """Both pre-training heads: masked-LM logits per token and 3-class logits per sequence."""

    def __init__(self, config):
        super().__init__()
        self.predictions = PoNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 3)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)`."""
        token_logits = self.predictions(sequence_output)
        sequence_logits = self.seq_relationship(pooled_output)
        return token_logits, sequence_logits
|
|
|
|
class PoNetPreTrainedModel(PreTrainedModel):
    """
    Common base of all PoNet models: wires up the config class and checkpoint prefix and supplies weight
    initialisation and the gradient-checkpointing toggle.
    """

    config_class = PoNetConfig
    base_model_prefix = "ponet"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise `module` in place according to its type.

        Linear and embedding weights are drawn from N(0, `initializer_range`);
        linear biases and the padding row of an embedding are zeroed; layer
        norms are reset to weight 1 and bias 0. Other modules are left alone.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad = module.padding_idx
            if pad is not None:
                module.weight.data[pad].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Switch gradient checkpointing on or off; only `PoNetEncoder` modules are affected."""
        if not isinstance(module, PoNetEncoder):
            return
        module.gradient_checkpointing = value
|
|
|
|
@dataclass
class PoNetForPreTrainingOutput(ModelOutput):
    """
    Output type of [`PoNetForPreTraining`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when both `labels` and `sentence_structural_label` are provided):
            Total loss: the sum of the masked language modeling loss and the sentence structural objective
            (classification) loss.
        mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned together with `loss`):
            Masked language modeling loss.
        sso_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned together with `loss`):
            Sentence structural objective loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 3)`):
            Prediction scores of the sentence structural objective head (scores of the three classes before
            SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length)`.

            Attention weights after the softmax: for every head, the distribution of the pooled query over the
            sequence positions, used to compute the global feature of the layer.
    """

    loss: Optional[torch.FloatTensor] = None
    mlm_loss: Optional[torch.FloatTensor] = None
    sso_loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
# Shared introduction passed to `add_start_docstrings` on the model classes below.
PONET_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PoNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
# Argument documentation shared by the `forward` methods; `{0}` is filled with
# the input shape through `str.format`, so the text must not contain any other
# curly braces. It now also documents `segment_ids`, which the forward methods
# accept but which was previously missing here.
PONET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        segment_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Index of the segment every token belongs to, used by the segment max-pooling of each layer. Values are
            used as indices along the sequence dimension, so they must lie in `[0, sequence_length - 1]`. If not
            provided, the indices are derived from `input_ids` by splitting the sequence at the `[CLS]` and `[SEP]`
            tokens.
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
# Backbone: embeddings -> PoNet encoder -> optional pooler.
@add_start_docstrings(
    "The bare PoNet Model transformer outputting raw hidden-states without any specific head on top.",
    PONET_START_DOCSTRING,
)
class PoNetModel(PoNetPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = PoNetEmbeddings(config)
        self.encoder = PoNetEncoder(config)

        # The pooler is optional; without it `pooler_output` is None.
        self.pooler = PoNetPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        segment_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
        # Unset flags fall back to the model configuration.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Turn the (batch, seq) padding mask into the broadcastable additive mask
        # consumed by the encoder layers.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Segment structure is normally read off the token ids. When only
        # `inputs_embeds` is given there are no ids to inspect; previously this
        # crashed inside the helpers. Fall back to the caller's `segment_ids`,
        # or to a single segment, and to "no delimiter tokens".
        if segment_ids is not None:
            segment_index = segment_ids
        elif input_ids is not None:
            segment_index = get_segment_index(input_ids)
        else:
            segment_index = torch.zeros(input_shape, dtype=torch.long, device=device)

        if input_ids is not None:
            token_type_mask = get_token_type_mask(input_ids)
        else:
            token_type_mask = torch.zeros(input_shape, dtype=torch.bool, device=device)

        encoder_outputs = self.encoder(
            embedding_output,
            segment_index,
            token_type_mask,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    PoNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    PONET_START_DOCSTRING,
)
class PoNetForPreTraining(PoNetPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.ponet = PoNetModel(config)
        self.cls = PoNetPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PoNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        segment_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        sentence_structural_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], PoNetForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Targets of the masked language modeling loss, given as vocabulary indices. Positions set to `-100`
            are ignored (masked); the loss is computed only over the remaining positions.
        sentence_structural_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets of the sentence structural objective (classification) loss. The input should be a sequence
            pair (see `input_ids` docstring). Indices should be in `[0, 1, 2]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence A is a continuation of sequence B,
            - 2 indicates sequence B is a random sequence.

            The losses are computed only when both `labels` and `sentence_structural_label` are given.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PoNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("ponet-base")
        >>> model = PoNetForPreTraining.from_pretrained("ponet-base")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.ponet(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            segment_ids=segment_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = backbone_outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = masked_lm_loss = sso_loss = None
        have_both_targets = labels is not None and sentence_structural_label is not None
        if have_both_targets:
            criterion = CrossEntropyLoss()
            token_logits = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = criterion(token_logits, labels.view(-1))
            sso_loss = criterion(seq_relationship_score.view(-1, 3), sentence_structural_label.view(-1))
            total_loss = masked_lm_loss + sso_loss

        if return_dict:
            return PoNetForPreTrainingOutput(
                loss=total_loss,
                mlm_loss=masked_lm_loss,
                sso_loss=sso_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # Tuple form: the three losses are prepended only when they were computed.
        head_outputs = (prediction_scores, seq_relationship_score) + backbone_outputs[2:]
        if total_loss is None:
            return head_outputs
        return (total_loss, masked_lm_loss, sso_loss) + head_outputs
|
|
|
|
@add_start_docstrings(
    """
    PoNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    PONET_START_DOCSTRING,
)
class PoNetForSequenceClassification(PoNetPreTrainedModel):
    # NOTE: the class-level documentation is supplied through the `add_start_docstrings` decorator above, so no
    # literal class docstring is written here (it would be appended to the generated one).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone encoder whose second output (index 1) feeds the classification head in `forward`.
        self.ponet = PoNetModel(config)

        # Use the classifier-specific dropout when configured, otherwise fall back to the hidden dropout rate.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Finalize construction through the base-class hook.
        self.post_init()

    @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        # Without an explicit checkpoint the generated usage example is formatted with `None` as the model name.
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        segment_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ponet(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            segment_ids=segment_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the backbone output is the pooled representation consumed by the head.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and label dtype; the result is stored on the
            # config, so subsequent calls skip this inference.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # Drop singleton dimensions on both sides so logits and labels are compared element-wise.
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple form: logits followed by the remaining backbone outputs, with the loss prepended when computed.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|