import torch
import torch.nn as nn
from torchcrf import CRF  # from the pytorch-crf package


class BERT_BiLSTM_CRF(nn.Module):
    """BERT encoder followed by a BiLSTM and a CRF layer for token classification."""

    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
        super().__init__()
        self.bert = base_model
        self.label2id = config.label2id
        self.id2label = config.id2label
        self.num_labels = config.num_labels

        # Two stacked BiLSTM layers over the BERT hidden states.
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,  # applied between the two stacked LSTM layers
        )
        self.dropout = nn.Dropout(dropout_rate)
        # Each LSTM direction contributes rnn_dim features, hence rnn_dim * 2.
        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
        emissions = self.classifier(lstm_out)
        mask = attention_mask.bool()

        if labels is not None:
            # The CRF cannot handle the -100 ignore index, so re-map those
            # positions to the 'O' label. Padding positions are excluded by
            # the mask anyway; in-sequence -100 positions (e.g. subword
            # continuations) are scored as 'O', a common approximation.
            safe_labels = labels.clone()
            safe_labels[labels == -100] = self.label2id['O']
            # The CRF returns the log-likelihood; negate it to get a loss.
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}
        else:
            # Viterbi decoding yields one variable-length tag list per
            # sequence; pad with 0 past the mask so the lists stack into a
            # rectangular tensor, returned under 'logits' so downstream code
            # (e.g. a Hugging Face Trainer loop) can consume it unchanged.
            decoded = self.crf.decode(emissions, mask=mask)
            max_len = input_ids.shape[1]
            padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
            logits = torch.tensor(padded_decoded, device=input_ids.device)
            return {'logits': logits}
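

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal sketch assuming a Hugging Face `transformers` backbone. The
# checkpoint name ('bert-base-cased') and the three-tag label set below are
# placeholder assumptions, not values from the source.
if __name__ == '__main__':
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    labels = ['O', 'B-PER', 'I-PER']  # hypothetical tag set
    config = AutoConfig.from_pretrained(
        'bert-base-cased',
        num_labels=len(labels),
        id2label=dict(enumerate(labels)),
        label2id={label: i for i, label in enumerate(labels)},
    )
    backbone = AutoModel.from_pretrained('bert-base-cased', config=config)
    model = BERT_BiLSTM_CRF(backbone, config)

    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
    batch = tokenizer(['John lives in Paris'], return_tensors='pt')

    # Inference path: no labels, so forward() runs Viterbi decoding.
    with torch.no_grad():
        out = model(batch['input_ids'], batch['attention_mask'],
                    batch.get('token_type_ids'))
    print(out['logits'])  # padded tag ids, shape (batch, seq_len)

    # Training-style call: -100 positions are re-mapped to 'O' internally.
    dummy_labels = torch.zeros_like(batch['input_ids'])
    dummy_labels[:, 0] = -100  # e.g. the [CLS] position
    out = model(batch['input_ids'], batch['attention_mask'],
                batch['token_type_ids'], labels=dummy_labels)
    print(out['loss'])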