| """Custom multi-label token classifier for HuggingFace Hub.""" |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import BertModel, BertPreTrainedModel |
|
|
|
|
| class MultiLabelCRF(nn.Module): |
| """Independent CRF per marker type for multi-label BIO tagging.""" |
|
|
| def __init__(self, num_types: int) -> None: |
| super().__init__() |
| self.num_types = num_types |
| self.transitions = nn.Parameter(torch.empty(num_types, 3, 3)) |
| self.start_transitions = nn.Parameter(torch.empty(num_types, 3)) |
| self.end_transitions = nn.Parameter(torch.empty(num_types, 3)) |
| self._reset_parameters() |
|
|
| def _reset_parameters(self) -> None: |
| nn.init.uniform_(self.transitions, -0.1, 0.1) |
| nn.init.uniform_(self.start_transitions, -0.1, 0.1) |
| nn.init.uniform_(self.end_transitions, -0.1, 0.1) |
| with torch.no_grad(): |
| self.transitions.data[:, 0, 2] = -10000.0 |
| self.start_transitions.data[:, 2] = -10000.0 |
|
|
| def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
| """Viterbi decoding. |
| |
| Args: |
| emissions: (batch, seq, num_types, 3) |
| mask: (batch, seq) boolean |
| |
| Returns: (batch, seq, num_types) best tag sequences |
| """ |
| batch, seq, num_types, _ = emissions.shape |
|
|
| |
| em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3) |
| mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq) |
|
|
| BT = batch * num_types |
|
|
| |
| trans = ( |
| self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3) |
| ) |
| start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3) |
| end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3) |
|
|
| arange = torch.arange(BT, device=em.device) |
| score = start + em[:, 0] |
| history: list[torch.Tensor] = [] |
|
|
| for i in range(1, seq): |
| broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1) |
| best_score, best_prev = broadcast.max(dim=1) |
| score = torch.where(mk[:, i].unsqueeze(1), best_score, score) |
| history.append(best_prev) |
|
|
| score = score + end |
| _, best_last = score.max(dim=1) |
|
|
| best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device) |
| seq_lengths = mk.sum(dim=1).long() |
| best_paths[arange, seq_lengths - 1] = best_last |
|
|
| for i in range(seq - 2, -1, -1): |
| prev_tag = history[i][arange, best_paths[:, i + 1]] |
| should_update = i < (seq_lengths - 1) |
| best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i]) |
|
|
| return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1) |
|
|
|
|
class HavelockTokenClassifier(BertPreTrainedModel):
    """Multi-label BIO token classifier with independent O/B/I heads per marker type.

    Each token gets num_types independent 3-way classifications, allowing
    overlapping spans (e.g. a token simultaneously B-anaphora and I-concessive).

    Output logits shape: (batch, seq_len, num_types, 3)
    """

    def __init__(self, config):
        """Args:
            config: BERT config; must carry ``num_types`` and may carry
                ``use_crf`` (defaults to False when absent).
        """
        super().__init__(config)
        self.num_types = config.num_types
        self.use_crf = getattr(config, "use_crf", False)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One 3-way (O/B/I) head per marker type, fused into a single linear.
        self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)

        if self.use_crf:
            self.crf = MultiLabelCRF(config.num_types)

        self.post_init()

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return raw emission logits of shape (batch, seq, num_types, 3).

        CRF/Viterbi decoding is NOT applied here — ``forward`` always yields
        logits so training losses can be computed on them; use :meth:`decode`
        for tag sequences.
        """
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        hidden = self.dropout(hidden)
        logits = self.classifier(hidden)
        batch, seq, _ = logits.shape
        return logits.view(batch, seq, self.num_types, 3)

    def decode(self, input_ids, attention_mask=None):
        """Run forward pass and return decoded tags of shape (batch, seq, num_types).

        Uses Viterbi decoding through the per-type CRF when ``use_crf`` is
        set; otherwise an independent per-token, per-type argmax over the
        O/B/I logits.
        """
        logits = self.forward(input_ids, attention_mask)
        if self.use_crf:
            # CRF needs an explicit boolean mask; treat everything as real
            # tokens when no attention_mask is supplied.
            mask = (
                attention_mask.bool()
                if attention_mask is not None
                else torch.ones(
                    logits.shape[:2], dtype=torch.bool, device=logits.device
                )
            )
            return self.crf.decode(logits, mask)
        return logits.argmax(dim=-1)
|
|