| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, List, Union, Tuple |
| | from transformers import ( |
| | PretrainedConfig, |
| | PreTrainedModel, |
| | AutoTokenizer, |
| | AutoConfig, |
| | AutoModel, |
| | AutoModelForSequenceClassification |
| | ) |
| | from transformers.models.bert.modeling_bert import ( |
| | BertEmbeddings, |
| | BertEncoder, |
| | load_tf_weights_in_bert |
| | ) |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPoolingAndCrossAttentions, |
| | SequenceClassifierOutput, |
| | MultipleChoiceModelOutput |
| | ) |
| |
|
| | from .configuration_bert import BertConfig |
| |
|
| |
|
class BertPreTrainedModel(PreTrainedModel):
    """Common base for the BERT models defined in this module.

    Pins the config class, the TF-checkpoint loader, the attribute prefix of
    the base model, and the per-module weight initialisation scheme.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``module``'s parameters in place.

        * ``nn.Linear`` / ``nn.Embedding``: weight ~ N(0, config.initializer_range);
          the Linear bias (if any) and the Embedding padding row (if any) are zeroed.
        * ``nn.LayerNorm``: weight set to 1, bias set to 0.
        Other module types are left as they are.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear:
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
| |
|
| |
|
class BertPooler(nn.Module):
    """Reduce the encoder output to one vector per sequence.

    Masked mean pooling over the sequence axis, then an optional
    ``hidden_size -> hidden_size`` linear projection (only when
    ``config.affine`` is truthy, otherwise identity), then ``tanh``.
    """

    def __init__(self, config):
        super().__init__()
        if config.affine:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        else:
            self.dense = nn.Identity()
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_states`` using ``attention_mask`` and apply dense + tanh.

        Args:
            hidden_states: token representations; ``mean_pooling`` expands the
                mask to this tensor's size, so it is presumably
                (batch, seq, hidden).
            attention_mask: per-token weights (presumably 0/1) of shape
                (batch, seq), i.e. hidden_states' shape without its last dim.

        Returns:
            The pooled, projected and tanh-activated tensor, in
            ``hidden_states.dtype``.
        """
        mean_tensor = self.mean_pooling(hidden_states, attention_mask)
        pooled_output = self.dense(mean_tensor)
        return self.activation(pooled_output)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average ``token_embeddings`` over dim 1, weighting each position by the mask.

        The reduction is carried out in at least float32 for numerical
        stability, then cast back to ``token_embeddings.dtype``. Without that
        cast, half-precision inputs would come back as float32 and then fail in
        a half-precision ``self.dense``. Rows whose mask sums to zero yield a
        zero vector, because the clamp turns 0/0 into 0/1e-9.
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return (summed / counts).to(token_embeddings.dtype)
| |
|
| |
|
class BertModel(BertPreTrainedModel):
    """BERT encoder (embeddings + transformer stack) with an optional masked-mean pooler.

    Args:
        config: model configuration (``BertConfig`` from this package).
        add_pooling_layer: when False, ``self.pooler`` is None and the pooled
            output returned by ``forward`` is None.
    """

    config_class = BertConfig

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        # PreTrainedModel hook for final setup (includes weight initialisation).
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode a batch and return the sequence output plus the pooled output.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Missing
        ``attention_mask`` defaults to all ones (covering cached past tokens as
        well); missing ``token_type_ids`` default to the embeddings' buffered
        ids or zeros. ``use_cache`` is only honoured when ``config.is_decoder``.

        Returns:
            If ``return_dict`` (falling back to ``config.use_return_dict``) is
            false: ``(sequence_output, pooled_output) + encoder_outputs[1:]``.
            Otherwise a ``BaseModelOutputWithPoolingAndCrossAttentions``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Number of tokens already held in the key/value cache (0 when no cache).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Broadcastable additive mask for the self-attention layers.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Cross-attention mask, only needed when decoding against encoder states.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooled_output = None
        if self.pooler is not None:
            pooler_mask = attention_mask
            if past_key_values_length > 0:
                # With a cache the mask spans past + current tokens, but
                # sequence_output only holds the current ones; pool over the
                # matching suffix so the mask expands to sequence_output's shape.
                pooler_mask = attention_mask[:, past_key_values_length:]
            pooled_output = self.pooler(sequence_output, pooler_mask)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
| |
|
| |
|
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT encoder followed by dropout and a linear head on the pooled output.

    The loss depends on ``config.problem_type``: MSE for "regression",
    cross-entropy for "single_label_classification", and BCE-with-logits for
    "multi_label_classification". If the type is unset, it is inferred from
    ``num_labels`` and the label dtype the first time labels are supplied, and
    written back into the config.
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        dropout_prob = config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # PreTrainedModel hook for final setup (includes weight initialisation).
        self.post_init()

    def _compute_loss(self, logits, labels):
        """Return the loss for ``logits`` vs ``labels``, or None for an unknown problem type."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                inferred = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                inferred = "single_label_classification"
            else:
                inferred = "multi_label_classification"
            self.config.problem_type = inferred

        problem = self.config.problem_type
        if problem == "regression":
            criterion = nn.MSELoss()
            if self.num_labels == 1:
                return criterion(logits.squeeze(), labels.squeeze())
            return criterion(logits, labels)
        if problem == "single_label_classification":
            criterion = nn.CrossEntropyLoss()
            return criterion(logits.view(-1, self.num_labels), labels.view(-1))
        if problem == "multi_label_classification":
            criterion = nn.BCEWithLogitsLoss()
            return criterion(logits, labels)
        return None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""Classify (or regress on) each sequence from the pooled BERT output.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Targets for the loss. Class indices in `[0, ..., config.num_labels - 1]`
            for single-label classification; if `config.num_labels == 1` an MSE
            regression loss is used instead.

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` (falling back to
            ``config.use_return_dict``) is true. Otherwise the tuple
            ``(loss?, logits, *outputs[2:])``, where loss is only present when
            it was computed.
        """
        use_dict = self.config.use_return_dict if return_dict is None else return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_dict,
        )

        logits = self.classifier(self.dropout(outputs[1]))
        loss = self._compute_loss(logits, labels) if labels is not None else None

        if not use_dict:
            tail = (logits,) + outputs[2:]
            return tail if loss is None else (loss,) + tail

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
class BertForMultipleChoice(BertPreTrainedModel):
    """BERT encoder with a one-logit-per-choice scoring head.

    Inputs carry a choices axis at dim 1. That axis is folded into the batch
    before encoding and unfolded again so the logits have one column per choice.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        dropout_prob = config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # PreTrainedModel hook for final setup (includes weight initialisation).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        """Score every choice and optionally compute a cross-entropy loss.

        The number of choices is read from dim 1 of ``input_ids``, or of
        ``inputs_embeds`` when ``input_ids`` is None. ``labels``, if given, are
        passed to ``nn.CrossEntropyLoss`` against the (-1, num_choices) logits.

        Returns:
            ``MultipleChoiceModelOutput`` when ``return_dict`` (falling back to
            ``config.use_return_dict``) is true. Otherwise the tuple
            ``(loss?, reshaped_logits, *outputs[2:])``.
        """
        use_dict = self.config.use_return_dict if return_dict is None else return_dict
        num_choices = (input_ids if input_ids is not None else inputs_embeds).shape[1]

        def _fold(t):
            # Merge every leading axis (batch, choices, ...) into one batch axis.
            return None if t is None else t.view(-1, t.size(-1))

        input_ids, attention_mask, token_type_ids, position_ids = (
            _fold(t) for t in (input_ids, attention_mask, token_type_ids, position_ids)
        )
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=use_dict,
        )

        scores = self.classifier(self.dropout(outputs[1]))
        reshaped_logits = scores.view(-1, num_choices)

        loss = nn.CrossEntropyLoss()(reshaped_logits, labels) if labels is not None else None

        if not use_dict:
            tail = (reshaped_logits,) + outputs[2:]
            return tail if loss is None else (loss,) + tail

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )