from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.bert.modeling_bert import BertModel as TransformersBertModel
from transformers.models.bert.modeling_bert import BertForMaskedLM as TransformersBertForMaskedLM
from transformers.models.bert.modeling_bert import BertForPreTraining as TransformersBertForPreTraining
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class BertModel(TransformersBertModel):
    def __init__(self, config):
        super().__init__(config)


class BertForMaskedLM(TransformersBertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)


class BertForPreTraining(TransformersBertForPreTraining):
    def __init__(self, config):
        super().__init__(config)

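# Note: the three subclasses above currently add nothing beyond the upstream
# transformers implementations. Presumably they serve as local extension points
# so that checkpoints resolve against this module's class names; that rationale
# is an assumption, since the source does not state it.
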

class DNABertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
| | r""" |
| | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., |
| | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If |
| | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
| | """ |
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Long inputs are split into non-overlapping 512-token chunks (BERT's maximum
        # input length), encoded independently, and their pooled outputs are averaged
        # back into one vector per original sequence further below.
        # NOTE: this path assumes token inputs, i.e. input_ids must be provided.
        batch_size, seq_len = input_ids.shape
        if seq_len > 512:
            assert seq_len % 512 == 0, "seq_len must be a multiple of 512 when longer than 512"
            # Fold the chunk dimension into the batch: (batch, seq) -> (batch * n_chunks, 512).
            input_ids = input_ids.view(-1, 512)
            attention_mask = attention_mask.view(-1, 512) if attention_mask is not None else None
            token_type_ids = token_type_ids.view(-1, 512) if token_type_ids is not None else None
            # Reset position_ids so every chunk gets fresh positions 0..511.
            position_ids = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [CLS]-based pooled representation, one row per (possibly chunked) input.
        pooled_output = outputs[1]

        if seq_len > 512:
            # Un-fold the chunk dimension: (batch * n_chunks, hidden) -> (batch, n_chunks, hidden),
            # then average the per-chunk pooled vectors into one vector per original sequence.
            pooled_output = pooled_output.view(batch_size, -1, pooled_output.shape[-1])
            pooled_output = torch.mean(pooled_output, dim=1)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype,
            # mirroring upstream BertForSequenceClassification.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
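

# A minimal usage sketch (illustrative only, not part of the original module):
# builds a small random-weight model and exercises the long-sequence path with
# a 1024-token batch, i.e. two 512-token chunks per example. The config values
# below are assumed toy hyperparameters, not those of any released checkpoint.
if __name__ == "__main__":
    from transformers.models.bert.configuration_bert import BertConfig

    config = BertConfig(
        vocab_size=4096,          # assumed toy vocabulary size
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
        num_labels=2,
    )
    model = DNABertForSequenceClassification(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 1024))  # batch of 2, seq_len 1024
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.logits.shape)  # expected: torch.Size([2, 2])
    print(out.loss)          # scalar cross-entropy loss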