Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torchcrf import CRF | |
class BERT_BiLSTM_CRF(nn.Module):
    """Sequence tagger: encoder -> dropout -> 2-layer BiLSTM -> linear -> CRF.

    Args:
        base_model: Encoder module. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` / ``attention_mask`` /
            ``token_type_ids``, return an object with ``last_hidden_state``.
        num_labels: Number of tags; size of the emission vector and of the CRF.
        dropout_rate: Dropout probability applied to the encoder output and
            between the two LSTM layers.
        rnn_dim: Hidden size of each LSTM direction (the BiLSTM output is
            ``2 * rnn_dim`` wide).
    """

    def __init__(self, base_model, num_labels, dropout_rate=0.2, rnn_dim=256):
        super().__init__()
        self.bert = base_model
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            # Follow the constructor argument instead of a hard-coded 0.2 so
            # a caller-supplied rate applies to the LSTM as well.
            dropout=dropout_rate,
        )
        self.dropout = nn.Dropout(dropout_rate)
        # Forward and backward LSTM states are concatenated: 2 * rnn_dim.
        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute the CRF loss (training) or the decoded tag ids (inference).

        Args:
            input_ids: Token ids passed to the encoder; its second dimension
                is used as the padded output length when decoding.
            attention_mask: Passed to the encoder and, converted to bool,
                used as the CRF mask.
            token_type_ids: Optional segment ids forwarded to the encoder.
            labels: Optional gold tag ids. Entries equal to ``-100`` are
                replaced by ``0`` before being given to the CRF.

        Returns:
            dict: With ``labels``: ``{'loss': ..., 'logits': emissions}``,
            where ``loss`` is the negated CRF output with ``reduction='mean'``
            and ``emissions`` are the per-token classifier scores.
            Without ``labels``: ``{'logits': tag_ids}``, where ``tag_ids`` is
            an integer tensor of the CRF-decoded sequences, each right-padded
            with ``0`` to ``input_ids.shape[1]``. Note that ``'logits'``
            therefore holds different kinds of values in the two modes.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
        emissions = self.classifier(lstm_out)
        mask = attention_mask.bool()

        if labels is not None:
            # -100 is not a valid CRF tag index, so map it to index 0
            # (the original note calls this the "O" index). These positions
            # are still scored wherever attention_mask is set, because the
            # CRF mask comes from attention_mask rather than from labels.
            safe_labels = labels.masked_fill(labels == -100, 0)
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}

        # decode() yields one Python list per sequence whose length may be
        # shorter than the padded width; right-pad with 0 so the batch can be
        # stacked into a single rectangular tensor.
        decoded = self.crf.decode(emissions, mask=mask)
        max_len = input_ids.shape[1]
        padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
        logits = torch.tensor(
            padded_decoded, dtype=torch.long, device=input_ids.device
        )
        return {'logits': logits}