|
|
import json
import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForTokenClassification,
    PreTrainedModel,
)
from transformers.modeling_outputs import TokenClassifierOutput
|
|
|
|
|
|
|
|
class TransformerCRFForTokenClassification(PreTrainedModel):
    """Token classification: a Transformers encoder, a linear head, and an
    optional CRF scoring layer on top.

    Fixes over the previous revision:
      * Subclasses ``PreTrainedModel`` instead of ``AutoModelForTokenClassification``.
        ``Auto*`` classes are factories whose ``__init__`` unconditionally raises
        ``EnvironmentError``, so the old class could never be instantiated.
      * ``AutoModel.from_config`` no longer receives ``use_safetensors`` — that
        kwarg belongs to ``from_pretrained`` and is rejected by backbone
        constructors.
      * CRF loss clamps the ``-100`` ignore index out of the label tensor and
        forces the first timestep of the mask on: torchcrf indexes tags even at
        masked positions and requires ``mask[:, 0]`` to be all ones.
      * ``save_pretrained`` no longer crashes when the caller also passes
        ``safe_serialization``; ``from_pretrained`` loads the CRF side file with
        ``map_location="cpu"`` so GPU-saved checkpoints load on CPU-only hosts.
    """

    # ``base_model`` is a reserved read-only property on PreTrainedModel, so the
    # backbone must live under a different attribute name.
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Backbone built from config only; pretrained weights are loaded later
        # by ``from_pretrained`` into the full model state dict.
        self.encoder = AutoModel.from_config(config)
        hidden_size = getattr(config, "hidden_size", 768)

        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden_size, config.num_labels)

        # The CRF is a registered submodule, so its parameters are saved and
        # loaded together with the rest of the model.
        self.use_crf = getattr(config, "use_crf", False)
        self.crf = CRF(num_tags=self.num_labels, batch_first=True) if self.use_crf else None
        self.loss_fn = nn.CrossEntropyLoss()

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Run encoder + classifier; if ``labels`` is given, compute the CRF
        negative log-likelihood (when enabled) or token-level cross-entropy.
        Positions labelled ``-100`` are excluded from the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_kwargs = {
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        # Forward these only when provided: not every backbone accepts them
        # (e.g. DistilBERT has no token_type_ids parameter).
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        if head_mask is not None:
            encoder_kwargs["head_mask"] = head_mask

        outputs = self.encoder(input_ids, **encoder_kwargs)

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.crf is not None:
                label_mask = labels != -100
                mask = label_mask if attention_mask is None else attention_mask.bool() & label_mask
                # torchcrf requires mask[:, 0] to be on for every sequence, and
                # it indexes tags even at masked positions, so the -100 ignore
                # index must be replaced with a valid tag id before scoring.
                mask[:, 0] = True
                loss = -self.crf(logits, labels.clamp(min=0), mask=mask, reduction="mean")
            else:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save config and weights; also write ``crf.pt`` for backward
        compatibility with checkpoints that stored the CRF separately.
        """
        # Persist the CRF flag so from_pretrained can rebuild the layer.
        self.config.use_crf = self.use_crf
        # setdefault instead of a hard-coded keyword: passing
        # ``safe_serialization=True`` alongside **kwargs raised a
        # duplicate-keyword TypeError whenever the caller supplied it too.
        kwargs.setdefault("safe_serialization", True)
        # PreTrainedModel.save_pretrained writes config.json itself; a separate
        # config.save_pretrained call is redundant.
        super().save_pretrained(save_directory, **kwargs)

        if self.crf is not None:
            torch.save(self.crf.state_dict(), os.path.join(save_directory, "crf.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load the model; the CRF layer is rebuilt from ``config.use_crf`` and,
        for older checkpoints, its weights are restored from a ``crf.pt`` file.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        if not hasattr(config, "use_crf"):
            config.use_crf = False

        # __init__ creates the CRF submodule from config.use_crf before weight
        # loading, so CRF weights stored in the main checkpoint are restored
        # here as well. (No forced use_safetensors: let transformers pick the
        # available weight format so .bin-only checkpoints still load.)
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **kwargs
        )

        # Back-compat: older checkpoints kept CRF weights in a side file.
        if model.crf is not None:
            crf_path = os.path.join(str(pretrained_model_name_or_path), "crf.pt")
            if os.path.isfile(crf_path):
                # map_location so GPU-saved weights load on CPU-only machines.
                model.crf.load_state_dict(torch.load(crf_path, map_location="cpu"))
        return model
|
|
|