import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from torchcrf import CRF


class BertCrfTokenClassification(nn.Module):
    """Token classification model: transformer encoder + linear emissions + CRF.

    Hidden states from a pretrained (RoBERTa-style) encoder are projected to
    tag space by a linear layer; a CRF layer scores label transitions, supplies
    the training loss (negative log-likelihood), and performs Viterbi decoding.
    """

    def __init__(self, base_model_id: str, num_labels: int):
        """Build the model.

        Args:
            base_model_id: HuggingFace model id or path of the base encoder.
            num_labels: Number of token-level tags (CRF state count).
        """
        super().__init__()
        self.num_labels = num_labels
        # Load the configuration and the base encoder.
        self.config = AutoConfig.from_pretrained(base_model_id)
        self.roberta = AutoModel.from_pretrained(base_model_id, config=self.config)
        # Linear layer mapping hidden states to tag space (CRF emissions).
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        # CRF layer for transition scores and Viterbi decoding.
        self.crf = CRF(num_tags=num_labels, batch_first=True)

    def _decode(self, emissions: torch.Tensor, mask) -> torch.Tensor:
        """Viterbi-decode emissions and right-pad the paths into one tensor.

        ``CRF.decode`` returns variable-length Python lists (one per sequence,
        up to each sequence's unmasked length); they are padded with tag 0 to
        the full sequence length so metrics code can treat them as a tensor.
        """
        paths = self.crf.decode(emissions, mask=mask)
        seq_len = emissions.size(1)
        padded = [p + [0] * (seq_len - len(p)) for p in paths]
        return torch.tensor(padded, device=emissions.device)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder, score emissions, and decode (and score) with the CRF.

        Args:
            input_ids: Token id tensor, shape (batch, seq).
            attention_mask: Optional padding mask, shape (batch, seq); 1 = real token.
            labels: Optional gold tag ids, shape (batch, seq); -100 marks
                positions to ignore (special/subword tokens, HF convention).
            **kwargs: Forwarded to the base encoder.

        Returns:
            dict with "logits" — decoded tag ids, shape (batch, seq) — and,
            when ``labels`` is given, "loss" — mean CRF negative log-likelihood.
        """
        outputs = self.roberta(input_ids, attention_mask=attention_mask, **kwargs)
        emissions = self.classifier(self.dropout(outputs[0]))

        # Padding mask for the CRF; None is accepted and means "all positions valid".
        base_mask = attention_mask.bool() if attention_mask is not None else None

        # Decode in both training (for metric evaluation) and inference modes.
        result = {"logits": self._decode(emissions, base_mask)}

        if labels is not None:
            # torchcrf cannot handle the -100 ignore-index: remap those
            # positions to tag 0 ("O"). They still contribute to the loss;
            # masking them out instead would violate torchcrf's requirement
            # that the first timestep of every sequence be unmasked
            # (presumably why a stricter mask was abandoned here).
            safe_labels = labels.masked_fill(labels == -100, 0)
            # CRF forward returns log-likelihood; negate for a minimizable loss.
            result["loss"] = -self.crf(
                emissions, tags=safe_labels, mask=base_mask, reduction='mean'
            )

        return result