| | """Custom multi-label token classifier — backbone-agnostic.""" |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel |
| |
|
| |
|
class MultiLabelCRF(nn.Module):
    """One independent linear-chain CRF per marker type (3 BIO tags each).

    Parameters / buffers:
        transitions:       (num_types, 3, 3) tag-to-tag scores per type.
        start_transitions: (num_types, 3) scores for the first tag.
        end_transitions:   (num_types, 3) scores for the last tag.
        emission_bias:     (1, 1, 1, 3) additive bias applied to emissions.
    """

    def __init__(self, num_types: int) -> None:
        super().__init__()
        self.num_types = num_types
        self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
        self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
        self.end_transitions = nn.Parameter(torch.empty(num_types, 3))

        self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Small uniform init, then block invalid BIO moves with large negatives."""
        for param in (self.transitions, self.start_transitions, self.end_transitions):
            nn.init.uniform_(param, -0.1, 0.1)
        with torch.no_grad():
            # Tag ids presumably 0=O, 1=B, 2=I: tag-0 -> tag-2 and starting on
            # tag-2 are forbidden — TODO confirm the label encoding with callers.
            self.transitions.data[:, 0, 2] = -10000.0
            self.start_transitions.data[:, 2] = -10000.0

    def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
        """Add the broadcastable emission bias when one is present."""
        if self.emission_bias is None:
            return emissions
        return emissions + self.emission_bias

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Viterbi decoding, vectorised over batch x marker type.

        Args:
            emissions: (batch, seq, num_types, 3)
            mask: (batch, seq) boolean

        Returns: (batch, seq, num_types) best tag sequences
        """
        emissions = self._apply_emission_bias(emissions)
        batch, seq, num_types, _ = emissions.shape
        flat = batch * num_types
        device = emissions.device

        # Fold (batch, type) into one leading axis so every chain decodes in parallel.
        em_flat = emissions.permute(0, 2, 1, 3).reshape(flat, seq, 3)
        mask_flat = mask.repeat_interleave(num_types, dim=0)

        # Tile the per-type parameters across the batch (batch-major, matching em_flat).
        trans_flat = self.transitions.repeat(batch, 1, 1)
        start_flat = self.start_transitions.repeat(batch, 1)
        end_flat = self.end_transitions.repeat(batch, 1)

        rows = torch.arange(flat, device=device)

        # Forward pass: running best score per tag, recording backpointers.
        running = start_flat + em_flat[:, 0]
        backpointers: list[torch.Tensor] = []
        for step in range(1, seq):
            candidate = running.unsqueeze(2) + trans_flat + em_flat[:, step].unsqueeze(1)
            step_best = candidate.max(dim=1)
            # Past the true end of a sequence the score stays frozen.
            running = torch.where(mask_flat[:, step].unsqueeze(1), step_best.values, running)
            backpointers.append(step_best.indices)

        final_tag = (running + end_flat).max(dim=1).indices

        # Backward pass: place the final tag at the last real position,
        # then follow backpointers toward the front.
        paths = torch.zeros(flat, seq, dtype=torch.long, device=device)
        lengths = mask_flat.sum(dim=1).long()
        paths[rows, lengths - 1] = final_tag
        for step in range(seq - 2, -1, -1):
            follow = backpointers[step][rows, paths[:, step + 1]]
            inside = step < (lengths - 1)
            paths[:, step] = torch.where(inside, follow, paths[:, step])

        return paths.reshape(batch, num_types, seq).permute(0, 2, 1)
| |
|
| |
|
class HavelockTokenConfig(PretrainedConfig):
    """Backbone-agnostic config: the backbone's fields plus our custom ones.

    Args:
        num_types: number of independent marker types (each gets 3 BIO tags).
        use_crf: whether a CRF decoding head is attached.
    """

    model_type = "havelock_token_classifier"

    def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
        # Backbone fields (hidden_size, dropout, ...) ride along in **kwargs.
        super().__init__(**kwargs)
        self.num_types, self.use_crf = num_types, use_crf
| |
|
| |
|
class HavelockTokenClassifier(PreTrainedModel):
    """Token classifier producing per-type BIO logits over any HF backbone."""

    config_class = HavelockTokenConfig

    def __init__(
        self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None
    ):
        """Build the head; reuse `backbone` when given, else construct from config.

        NOTE(review): the `backbone is None` branch calls
        `AutoModel.from_config(config)` with our custom `model_type` —
        confirm AutoModel can resolve it, otherwise this branch may fail
        (e.g. when loading with `from_pretrained` without a backbone).
        """
        super().__init__(config)
        self.num_types = config.num_types
        self.use_crf = config.use_crf

        if backbone is not None:
            self.bert = backbone
        else:
            self.bert = AutoModel.from_config(config)

        # Fall back to 0.1 for backbones whose config lacks hidden_dropout_prob.
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        # One 3-way (BIO) logit group per marker type.
        self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)

        if self.use_crf:
            self.crf = MultiLabelCRF(config.num_types)

        self.post_init()

    @classmethod
    def from_backbone(
        cls,
        model_name: str,
        num_types: int,
        use_crf: bool = False,
        obi_bias: torch.Tensor | None = None,
    ) -> "HavelockTokenClassifier":
        """Build from a pretrained backbone name — the training entrypoint."""
        backbone = AutoModel.from_pretrained(model_name)
        backbone_config = backbone.config

        # NOTE(review): `to_dict()` also carries the backbone's `model_type`
        # into **kwargs — verify it does not shadow the class-level
        # HavelockTokenConfig.model_type on the instance after a save/load
        # round trip.
        config = HavelockTokenConfig(
            num_types=num_types,
            use_crf=use_crf,
            **backbone_config.to_dict(),
        )
        model = cls(config, backbone=backbone)
        if use_crf and obi_bias is not None:
            # Replaces the CRF's zero emission-bias buffer; reshaped to the
            # broadcastable (1, 1, 1, 3) layout the CRF expects.
            model.crf.emission_bias = obi_bias.reshape(1, 1, 1, 3)
        return model

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return raw emission logits of shape (batch, seq, num_types, 3).

        NOTE(review): extra **kwargs (e.g. token_type_ids) are accepted but
        silently dropped — confirm callers don't rely on them reaching the
        backbone.
        """
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        hidden = self.dropout(hidden)
        logits = self.classifier(hidden)
        batch, seq, _ = logits.shape
        # Split the flat (num_types * 3) logits into per-type 3-way groups.
        return logits.view(batch, seq, self.num_types, 3)

    def decode(self, input_ids, attention_mask=None):
        """Predict tag ids: CRF Viterbi when enabled, else per-token argmax.

        Returns: (batch, seq, num_types) long tensor of tag indices.
        """
        logits = self.forward(input_ids, attention_mask)
        if self.use_crf:
            # CRF needs a boolean mask; treat all positions as valid if none given.
            mask = (
                attention_mask.bool()
                if attention_mask is not None
                else torch.ones(
                    logits.shape[:2], dtype=torch.bool, device=logits.device
                )
            )
            return self.crf.decode(logits, mask)
        return logits.argmax(dim=-1)
| |
|