| | from torch import nn |
| | from transformers import PreTrainedModel, PretrainedConfig |
| | from transformers import BertModel, BertConfig |
| | from transformers import AutoModelForTokenClassification, AutoConfig |
| | from torchcrf import CRF |
| |
|
class BERT_CRF_Config(PretrainedConfig):
    """Configuration class for the BERT_CRF model.

    Every keyword argument is handed to ``PretrainedConfig.__init__``
    unchanged; this class itself only adds the ``model_name`` attribute.
    """

    # Identifier under which this configuration type is known.
    model_type = "BERT_CRF"

    def __init__(self, **kwargs):
        """Initialise the base configuration, then tag it with the model name.

        Args:
            **kwargs: Forwarded as-is to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Set after the base initialiser, so it always ends up as "BERT_CRF".
        self.model_name = "BERT_CRF"
| |
|
| |
|
class BERT_CRF(PreTrainedModel):
    """BERT encoder followed by dropout, a linear projection and a CRF.

    ``forward`` has two modes selected by ``labels``: with labels it returns
    a training loss, with ``labels=None`` it returns decoded tag sequences.
    """

    config_class = BERT_CRF_Config

    def __init__(self, config):
        """Build the encoder, dropout, projection and CRF layers.

        Args:
            config: Configuration object. This constructor reads
                ``config.bert_name`` (passed to ``from_pretrained`` for both
                the BERT config and the BERT weights) and
                ``config.num_labels`` (projection width and CRF tag count).
        """
        super().__init__(config)

        bert_config = BertConfig.from_pretrained(config.bert_name)
        # The encoder is configured to also return attentions and all hidden
        # states. NOTE(review): forward() only reads 'last_hidden_state';
        # confirm whether external code relies on these extra outputs.
        bert_config.output_attentions = True
        bert_config.output_hidden_states = True

        self.bert = BertModel.from_pretrained(config.bert_name, config=bert_config)
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(self.bert.config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels, labels_mask):
        """Compute the CRF loss (training) or decode tag sequences (inference).

        Args:
            input_ids: Passed to the BERT encoder as ``input_ids``.
            token_type_ids: Passed to the BERT encoder as ``token_type_ids``.
            attention_mask: Passed to the BERT encoder as ``attention_mask``.
            labels: Per-position gold tags, iterated along its first
                dimension in step with the logits; ``None`` selects decoding.
            labels_mask: Tensor marking, per sequence, which positions are
                scored/decoded. It is cast to ``bool`` so that 0/1 integer
                masks select positions instead of being used as indices.

        Returns:
            With ``labels``: a scalar tensor, the sum over sequences of the
            negated ``token_mean`` CRF output divided by the batch size
            (sequences with no selected position add nothing but still count
            in the divisor). With ``labels=None``: a list with one list of
            tags per sequence (empty when no position is selected).
        """
        encoder_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        hidden = self.dropout(encoder_output['last_hidden_state'])
        logits = self.linear(hidden)

        # Integer masks would be interpreted as row indices by tensor
        # indexing; a bool mask guarantees element selection.
        labels_mask = labels_mask.bool()

        if labels is not None:
            return self._masked_nll(logits, labels, labels_mask)
        return self._masked_decode(logits, labels_mask)

    def _masked_nll(self, logits, labels, labels_mask):
        """Return the batch-averaged negated CRF score over masked positions."""
        batch_size = logits.shape[0]
        total = None

        for seq_logits, seq_labels, seq_mask in zip(logits, labels, labels_mask):
            # Each sequence is scored on its own as a batch of one, because
            # the number of selected positions differs between sequences.
            seq_logits = seq_logits[seq_mask].unsqueeze(0)
            seq_labels = seq_labels[seq_mask].unsqueeze(0)

            if seq_logits.numel() != 0:
                seq_loss = -self.crf(seq_logits, seq_labels, reduction='token_mean')
                total = seq_loss if total is None else total + seq_loss

        if total is None:
            # No sequence had a selected position: return a zero that is
            # still a tensor connected to `logits`, so callers can call
            # .backward() / .item() exactly as in the normal case.
            total = logits.sum() * 0.0

        return total / batch_size

    def _masked_decode(self, logits, labels_mask):
        """Return the decoded tag list of every sequence, masked positions only."""
        output_tags = []

        for seq_logits, seq_mask in zip(logits, labels_mask):
            seq_logits = seq_logits[seq_mask].unsqueeze(0)

            if seq_logits.numel() != 0:
                # decode() yields one tag list per batch entry; batch is 1.
                output_tags.append(self.crf.decode(seq_logits)[0])
            else:
                output_tags.append([])

        return output_tags
| |
|
| |
|
class ModelRegisterStep:
    """Callable step that registers BERT_CRF with the transformers Auto classes."""

    def __call__(self, args):
        """Register the config and model classes, then pass ``args`` through.

        Args:
            args: Mapping of step arguments; it is not modified.

        Returns:
            A new dict holding the same key/value pairs as ``args``.
        """
        # The config is registered under its name first, then the model class
        # is registered against that config class.
        AutoConfig.register("BERT_CRF", BERT_CRF_Config)
        AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)

        passthrough = {**args}
        return passthrough
| |
|