# MPTR_AutoT/prompt_model_factory.py
from typing import Optional

import torch
from transformers import BertModel
from transformers.models.bert.modeling_bert import (
    BertOnlyMLMHead,
    BertPreTrainedModel,
)


class BertForPromptFinetuning(BertPreTrainedModel):
    """BERT classifier that scores label words at a prompt's [MASK] position."""

    def __init__(self, config, use_multi_label_words: bool = False):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.init_weights()

        # Vocabulary ids of the label words; must be set before forward().
        # A list of ids (one per class), or a list of id lists per class
        # when `use_multi_label_words` is True.
        self.label_word_ids = None
        self.use_multi_label_words = use_multi_label_words

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mask_pos: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        if mask_pos is None:
            raise ValueError("`mask_pos` should be assigned!")
        mask_pos = mask_pos.squeeze()

        # Encode the full sequence
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        # Gather the hidden state at each example's [MASK] position
        sequence_output = outputs[0]
        sequence_mask_output = sequence_output[
            torch.arange(sequence_output.size(0)), mask_pos
        ]

        # MLM logits over the vocabulary at the [MASK] position
        # prediction_mask_scores.shape: [batch_size, vocab_size]
        prediction_mask_scores = self.cls(sequence_mask_output)
        # Reduce the vocabulary logits to one logit per class
        logits = []
        if self.use_multi_label_words:
            # Each class maps to several label-word ids
            for word_ids in self.label_word_ids:
                one_label_logits = []
                for word_id in word_ids:
                    one_label_word_logits = prediction_mask_scores[:, word_id]
                    one_label_logits.append(one_label_word_logits.unsqueeze(-1))
                # one_label_logits: (bs, num_label_words)
                one_label_logits = torch.cat(one_label_logits, -1)
                # Keep the max logit so the best-scoring label word
                # represents the class
                logits.append(torch.max(one_label_logits, dim=1, keepdim=True)[0])
        else:
            # Each class maps to a single label-word id
            for label_word_id in self.label_word_ids:
                logits.append(prediction_mask_scores[:, label_word_id].unsqueeze(-1))

        # logits.shape: [batch_size, num_classes]
        logits = torch.sigmoid(torch.cat(logits, -1))
        loss = None
        if labels is not None:
            # Binary cross-entropy on the sigmoid outputs against float targets
            loss_fct = torch.nn.BCELoss()
            loss = loss_fct(logits, labels.float())

        output = (logits, outputs.hidden_states) if output_hidden_states else (logits,)
        # Append attentions as a single tuple element, mirroring hidden_states
        output = (output + (outputs.attentions,)) if output_attentions else output
        return ((loss,) + output) if loss is not None else output
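

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the checkpoint
# name, prompt text, and label words below are assumptions chosen to show the
# expected call pattern, i.e. how `label_word_ids` and `mask_pos` are supplied.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # from_pretrained loads the encoder and MLM-head weights into
    # self.bert / self.cls; weights would be random if built from config alone.
    model = BertForPromptFinetuning.from_pretrained("bert-base-uncased")

    # Single-label-word mode: one vocabulary id per class,
    # e.g. "bad" -> class 0, "good" -> class 1
    model.label_word_ids = tokenizer.convert_tokens_to_ids(["bad", "good"])

    encoded = tokenizer("A gripping film. It was [MASK].", return_tensors="pt")
    # Index of the [MASK] token in each example
    mask_pos = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero()[:, 1]

    logits = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        token_type_ids=encoded["token_type_ids"],
        mask_pos=mask_pos,
    )[0]
    print(logits.shape)  # torch.Size([1, 2]): one sigmoid score per class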