| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModel, AutoConfig
|
| from torchcrf import CRF
|
|
|
class BertCrfTokenClassification(nn.Module):
    """Token-classification model: transformer encoder + linear emissions + CRF.

    A pretrained encoder (loaded via ``AutoModel``) produces per-token hidden
    states, a linear layer maps them to per-tag emission scores, and a CRF
    layer models tag-transition structure. Training minimizes the negative
    CRF log-likelihood; inference uses Viterbi decoding.

    Note: ``forward`` returns decoded tag ids under the key ``"logits"``
    (not raw emission scores) so the output dict is drop-in compatible with
    HF-style training loops that read ``outputs["logits"]``.
    """

    def __init__(self, base_model_id: str, num_labels: int):
        """Build the encoder, emission head, and CRF.

        Args:
            base_model_id: Hugging Face model id or local path for the encoder.
            num_labels: Number of token-level tags (CRF tag-space size).
        """
        super().__init__()
        self.num_labels = num_labels

        self.config = AutoConfig.from_pretrained(base_model_id)
        self.roberta = AutoModel.from_pretrained(base_model_id, config=self.config)

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        self.crf = CRF(num_tags=num_labels, batch_first=True)

    def _decode(self, emissions, mask):
        """Viterbi-decode emissions into a (batch, seq_len) tag-id tensor.

        ``CRF.decode`` returns variable-length Python lists (one per sequence,
        truncated to each sequence's mask length); right-pad them with tag 0
        so the result stacks into a rectangular tensor.
        """
        best_paths = self.crf.decode(emissions, mask=mask)
        seq_len = emissions.size(1)
        padded = [path + [0] * (seq_len - len(path)) for path in best_paths]
        return torch.tensor(padded, device=emissions.device)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder + CRF.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Optional 0/1 padding mask, shape (batch, seq_len).
            labels: Optional gold tag ids, shape (batch, seq_len), with -100
                marking positions to ignore (special tokens / subword pieces).
            **kwargs: Forwarded to the underlying encoder.

        Returns:
            dict with ``"logits"`` (decoded tag ids, (batch, seq_len)) and,
            when ``labels`` is given, ``"loss"`` (mean negative CRF
            log-likelihood).
        """
        outputs = self.roberta(input_ids, attention_mask=attention_mask, **kwargs)
        sequence_output = self.dropout(outputs[0])
        emissions = self.classifier(sequence_output)

        # torchcrf expects a bool mask; None means "all positions valid".
        base_mask = attention_mask.bool() if attention_mask is not None else None

        if labels is None:
            return {"logits": self._decode(emissions, base_mask)}

        # The CRF cannot index tag -100, so remap ignored positions to tag 0.
        # They stay inside ``base_mask`` (torchcrf requires mask[:, 0] to be
        # on for every sequence), which slightly biases the loss toward tag 0
        # at those positions — NOTE(review): accepted trade-off here; confirm
        # this matches the evaluation's label-alignment strategy.
        safe_labels = labels.masked_fill(labels == -100, 0)

        loss = -self.crf(emissions, tags=safe_labels, mask=base_mask, reduction='mean')

        return {"loss": loss, "logits": self._decode(emissions, base_mask)}