File size: 2,726 Bytes
cb2d481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from torchcrf import CRF

class BertCrfTokenClassification(nn.Module):
    """Pretrained encoder + linear emission layer + CRF for token tagging.

    The encoder is loaded with ``AutoModel`` and stored as ``self.roberta``.
    That attribute name is kept so existing state dicts still load, even though
    any ``AutoModel``-loadable encoder works. A dropout + linear layer maps each
    hidden state to per-tag emission scores. A ``CRF`` layer scores gold tag
    sequences (loss) and runs Viterbi decoding (predictions).

    ``forward`` returns a dict whose ``"logits"`` entry holds *decoded tag ids*
    with shape ``(batch, seq_len)``, not emission scores. Positions beyond each
    sequence's decoded length are filled with ``fill_label_id``.
    """

    def __init__(self, base_model_id: str, num_labels: int, fill_label_id: int = 0):
        """Build the encoder, emission head and CRF.

        Args:
            base_model_id: Model id or local path handed to
                ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``.
            num_labels: Size of the tag space (emission width and CRF tags).
            fill_label_id: Tag id that replaces ignored (``-100``) labels
                before they reach the CRF. It is also used to pad decoded paths.
                The default 0 presumes tag 0 is the "outside"/``O`` tag;
                override it if your label mapping puts something else at 0.

        Raises:
            ValueError: If ``fill_label_id`` is not in ``[0, num_labels)``.
                This is checked before any pretrained weights are loaded.
        """
        super().__init__()
        if not 0 <= fill_label_id < num_labels:
            raise ValueError(
                f"fill_label_id must be in [0, {num_labels}), got {fill_label_id}"
            )
        self.num_labels = num_labels
        self.fill_label_id = fill_label_id

        # Load the configuration and the pretrained encoder.
        self.config = AutoConfig.from_pretrained(base_model_id)
        self.roberta = AutoModel.from_pretrained(base_model_id, config=self.config)

        # Dropout + linear projection from hidden states to tag space (emissions).
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        # CRF layer: learned tag-transition scores and Viterbi decoding.
        self.crf = CRF(num_tags=num_labels, batch_first=True)

    @staticmethod
    def _fill_ignored(labels: torch.Tensor, fill_value: int, ignore_index: int = -100) -> torch.Tensor:
        """Return a copy of ``labels`` with every ``ignore_index`` entry set to ``fill_value``."""
        return labels.masked_fill(labels == ignore_index, fill_value)

    @staticmethod
    def _pad_paths(paths, seq_len: int, pad_value: int, device) -> torch.Tensor:
        """Right-pad decoded tag paths to ``seq_len``; return a ``(len(paths), seq_len)`` LongTensor."""
        padded = [list(path) + [pad_value] * (seq_len - len(path)) for path in paths]
        # reshape keeps the 2-D shape even for an empty batch.
        return torch.tensor(padded, dtype=torch.long, device=device).reshape(len(paths), seq_len)

    def _decode(self, emissions: torch.Tensor, mask) -> torch.Tensor:
        """Viterbi-decode ``emissions`` under ``mask`` and return padded tag ids."""
        paths = self.crf.decode(emissions, mask=mask)
        return self._pad_paths(paths, emissions.size(1), self.fill_label_id, emissions.device)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Encode, score and (optionally) compute the CRF loss.

        Args:
            input_ids: Token ids, passed straight to the encoder.
            attention_mask: Optional padding mask. Non-zero entries mark the
                positions the CRF scores and decodes. When ``None``, every
                position is treated as valid.
            labels: Optional gold tag ids aligned with ``input_ids``. Entries
                equal to ``-100`` are replaced with ``fill_label_id`` before
                the CRF sees them.
            **kwargs: Extra encoder arguments (e.g. ``token_type_ids``).

        Returns:
            ``{"loss": ..., "logits": ...}`` when ``labels`` is given, otherwise
            ``{"logits": ...}``. ``loss`` is the negative CRF log-likelihood
            with ``reduction="mean"``. ``logits`` are decoded tag ids of shape
            ``(batch, seq_len)``.
        """
        # Some transformers.Trainer versions inject `num_items_in_batch` into
        # any forward() that accepts **kwargs. The encoder does not accept it,
        # and the CRF loss below does not use it, so drop it here.
        kwargs.pop("num_items_in_batch", None)

        outputs = self.roberta(input_ids, attention_mask=attention_mask, **kwargs)
        sequence_output = self.dropout(outputs[0])
        emissions = self.classifier(sequence_output)

        # The CRF covers every non-padding token. pytorch-crf requires the first
        # timestep of each sequence to be unmasked and assumes the unmasked
        # positions form a contiguous prefix. That rules out right-to-left or
        # left padding (NOTE(review): right-padded batches assumed). It also
        # rules out a stricter mask that skips -100 (special/subword) tokens,
        # so those positions stay in and are trained toward `fill_label_id`.
        mask = attention_mask.bool() if attention_mask is not None else None

        if labels is None:
            # Inference mode: decode only.
            return {"logits": self._decode(emissions, mask)}

        # pytorch-crf cannot index with -100, so substitute a real tag id.
        tags = self._fill_ignored(labels, self.fill_label_id)
        loss = -self.crf(emissions, tags=tags, mask=mask, reduction="mean")

        # Decoded paths are returned so evaluation metrics can use them.
        return {"loss": loss, "logits": self._decode(emissions, mask)}