"""Custom multi-label token classifier — backbone-agnostic.""" import torch import torch.nn as nn from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel class MultiLabelCRF(nn.Module): """Independent CRF per marker type for multi-label BIO tagging.""" def __init__(self, num_types: int) -> None: super().__init__() self.num_types = num_types self.transitions = nn.Parameter(torch.empty(num_types, 3, 3)) self.start_transitions = nn.Parameter(torch.empty(num_types, 3)) self.end_transitions = nn.Parameter(torch.empty(num_types, 3)) # Placeholder — will be overwritten by loaded weights if present self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3)) self._reset_parameters() def _reset_parameters(self) -> None: nn.init.uniform_(self.transitions, -0.1, 0.1) nn.init.uniform_(self.start_transitions, -0.1, 0.1) nn.init.uniform_(self.end_transitions, -0.1, 0.1) with torch.no_grad(): self.transitions.data[:, 0, 2] = -10000.0 self.start_transitions.data[:, 2] = -10000.0 def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor: if self.emission_bias is not None: return emissions + self.emission_bias return emissions def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """Viterbi decoding. Args: emissions: (batch, seq, num_types, 3) mask: (batch, seq) boolean Returns: (batch, seq, num_types) best tag sequences """ # Apply emission bias before decoding emissions = self._apply_emission_bias(emissions) batch, seq, num_types, _ = emissions.shape # Reshape to (batch*num_types, seq, 3) em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3) mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq) BT = batch * num_types # Expand params across batch trans = ( self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3) ) start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3) end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3) arange = torch.arange(BT, device=em.device) score = start + em[:, 0] history: list[torch.Tensor] = [] for i in range(1, seq): broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1) best_score, best_prev = broadcast.max(dim=1) score = torch.where(mk[:, i].unsqueeze(1), best_score, score) history.append(best_prev) score = score + end _, best_last = score.max(dim=1) best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device) seq_lengths = mk.sum(dim=1).long() best_paths[arange, seq_lengths - 1] = best_last for i in range(seq - 2, -1, -1): prev_tag = history[i][arange, best_paths[:, i + 1]] should_update = i < (seq_lengths - 1) best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i]) return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1) class HavelockTokenConfig(PretrainedConfig): """Config that wraps any backbone config + our custom fields.""" model_type = "havelock_token_classifier" def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs): super().__init__(**kwargs) self.num_types = num_types self.use_crf = use_crf class HavelockTokenClassifier(PreTrainedModel): config_class = HavelockTokenConfig def __init__( self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None ): super().__init__(config) self.num_types = config.num_types self.use_crf = config.use_crf # Accept injected backbone (from_pretrained path) or build from config if backbone is not None: self.bert = backbone else: self.bert = AutoModel.from_config(config) self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1)) self.classifier = nn.Linear(config.hidden_size, config.num_types * 3) if self.use_crf: self.crf = MultiLabelCRF(config.num_types) self.post_init() @classmethod def from_backbone( cls, model_name: str, num_types: int, use_crf: bool = False, obi_bias: torch.Tensor | None = None, ) -> "HavelockTokenClassifier": """Build from a pretrained backbone name — the training entrypoint.""" backbone = AutoModel.from_pretrained(model_name) backbone_config = backbone.config config = HavelockTokenConfig( num_types=num_types, use_crf=use_crf, **backbone_config.to_dict(), ) model = cls(config, backbone=backbone) if use_crf and obi_bias is not None: model.crf.emission_bias = obi_bias.reshape(1, 1, 1, 3) return model def forward(self, input_ids, attention_mask=None, **kwargs): hidden = self.bert( input_ids=input_ids, attention_mask=attention_mask ).last_hidden_state hidden = self.dropout(hidden) logits = self.classifier(hidden) batch, seq, _ = logits.shape return logits.view(batch, seq, self.num_types, 3) def decode(self, input_ids, attention_mask=None): logits = self.forward(input_ids, attention_mask) if self.use_crf: mask = ( attention_mask.bool() if attention_mask is not None else torch.ones( logits.shape[:2], dtype=torch.bool, device=logits.device ) ) return self.crf.decode(logits, mask) return logits.argmax(dim=-1)