from typing import Optional

import torch
from transformers import BertModel
from transformers.models.bert.modeling_bert import (
    BertOnlyMLMHead,
    BertPreTrainedModel,
)


class BertForPromptFinetuning(BertPreTrainedModel):
    """BERT for prompt-based fine-tuning.

    Classification is cast as masked language modeling: the MLM head scores
    the verbalizer's label words at the [MASK] position, and those scores
    become the per-label logits.
    """

    def __init__(self, config, use_multi_label_words: bool = False):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

        # Verbalizer: token ids of the label words, one entry per label
        # (a single id per label, or a list of ids per label when
        # `use_multi_label_words` is True). Must be set before forward().
        self.label_word_ids = None
        self.use_multi_label_words = use_multi_label_words

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mask_pos: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        if mask_pos is None:
            raise ValueError("`mask_pos` should be assigned!")
        mask_pos = mask_pos.squeeze()

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        # Gather the hidden state at each sequence's [MASK] position.
        sequence_output = outputs[0]
        sequence_mask_output = sequence_output[
            torch.arange(sequence_output.size(0)), mask_pos
        ]

        # MLM scores over the full vocabulary at the [MASK] position.
        prediction_mask_scores = self.cls(sequence_mask_output)

        logits = []
        if self.use_multi_label_words:
            # Several label words per label: take the max MLM score among
            # a label's words as that label's logit.
            for word_ids in self.label_word_ids:
                one_label_logits = []
                for word_id in word_ids:
                    one_label_logits.append(
                        prediction_mask_scores[:, word_id].unsqueeze(-1)
                    )
                one_label_logits = torch.cat(one_label_logits, -1)
                logits.append(torch.max(one_label_logits, dim=1, keepdim=True)[0])
        else:
            # One label word per label: its MLM score is the label's logit.
            for label_id in range(len(self.label_word_ids)):
                logits.append(
                    prediction_mask_scores[:, self.label_word_ids[label_id]].unsqueeze(-1)
                )

        # Sigmoid per label (rather than a softmax across labels), matching
        # the binary cross-entropy loss below.
        logits = torch.sigmoid(torch.cat(logits, -1))

        loss = None
        if labels is not None:
            # BCELoss requires targets with the same shape as `logits`, so
            # `labels` is expected as one-hot/multi-hot float targets of
            # shape (batch_size, num_labels).
            loss_fct = torch.nn.BCELoss()
            loss = loss_fct(logits, labels.float())

        output = (logits, outputs.hidden_states) if output_hidden_states else (logits,)
        output = (output + (outputs.attentions,)) if output_attentions else output

        return ((loss,) + output) if loss is not None else output
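

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint
    # name, prompt template, and verbalizer below are illustrative
    # assumptions for demonstration only.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForPromptFinetuning.from_pretrained("bert-base-uncased")

    # Wrap each input in a prompt template ending in a [MASK] slot.
    texts = [
        "a gripping, well-acted thriller . It was [MASK] .",
        "flat characters and a predictable plot . It was [MASK] .",
    ]
    encoding = tokenizer(texts, return_tensors="pt", padding=True)

    # Position of [MASK] in each sequence, shape (batch_size,).
    mask_pos = (encoding["input_ids"] == tokenizer.mask_token_id).int().argmax(dim=1)

    # Single-word verbalizer: label 0 -> "terrible", label 1 -> "great".
    model.label_word_ids = tokenizer.convert_tokens_to_ids(["terrible", "great"])
    # Multi-word alternative (requires use_multi_label_words=True):
    # model.label_word_ids = [
    #     tokenizer.convert_tokens_to_ids(["terrible", "awful"]),
    #     tokenizer.convert_tokens_to_ids(["great", "good"]),
    # ]

    logits = model(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        mask_pos=mask_pos,
    )[0]
    print(logits)  # per-label sigmoid scores, shape (batch_size, 2)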