import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchcrf import CRF
from transformers import AutoModel
|
|
class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
    """Token classification model: transformer encoder + linear head + CRF layer.

    The Hub mixin (``PyTorchModelHubMixin``) provides ``save_pretrained`` /
    ``from_pretrained`` for this wrapper itself.

    Expected keys in ``config``:
        pretrained_model_name: HF model id for the encoder backbone.
        hidden_dropout_prob:   dropout applied to the encoder output.
        hidden_size:           encoder hidden dimension (must match the backbone).
        num_labels:            number of tag classes.
        label_map:             mapping label name -> label id.
    """

    def __init__(self, config):
        super().__init__()

        # Pretrained encoder backbone.
        self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])

        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])

        # CRF over the per-token label scores; batch_first matches the
        # (batch, seq_len, num_labels) layout of the classifier output.
        self.crf = CRF(config["num_labels"], batch_first=True)

        # Inverse mapping id -> label name, for turning decoded ids back
        # into tag strings.
        self.id2label = {v: k for k, v in config["label_map"].items()}

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder + head; compute the CRF negative log-likelihood if labels given.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attention_mask: optional 0/1 mask, shape (batch, seq_len).
            labels: optional gold label ids, shape (batch, seq_len).

        Returns:
            dict with "logits" (batch, seq_len, num_labels) and, when
            ``labels`` is provided, "loss" (mean CRF NLL over the batch).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        if labels is not None:
            # FIX: the original called attention_mask.bool() unconditionally,
            # crashing with AttributeError when attention_mask was None even
            # though the signature declares it optional. torchcrf's CRF
            # accepts mask=None and builds an all-ones mask itself.
            mask = attention_mask.bool() if attention_mask is not None else None
            # NOTE(review): torchcrf requires mask[:, 0] to be all True; a
            # standard right-padded attention mask satisfies this — confirm
            # if left-padded inputs are ever used.
            loss = -self.crf(logits, labels, mask=mask, reduction='mean')
            return {"loss": loss, "logits": logits}

        return {"logits": logits}

    def decode(self, logits, mask):
        """Viterbi-decode the best label id sequence per example.

        Args:
            logits: emission scores, shape (batch, seq_len, num_labels).
            mask: 0/1 validity mask, shape (batch, seq_len).

        Returns:
            list of lists of label ids (one inner list per batch element,
            lengths follow the mask).
        """
        return self.crf.decode(logits, mask.bool())