# bert-token-classifier / modeling_havelock.py
# Hugging Face hub upload residue: uploaded via huggingface_hub by
# permutans (commit 03a30b6, verified).
"""Custom multi-label token classifier — backbone-agnostic."""
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
class MultiLabelCRF(nn.Module):
    """One independent linear-chain CRF per marker type (multi-label BIO).

    Each type is a 3-state chain over tags O/B/I (indices 0/1/2). All
    (batch, type) chains are folded into one flat batch so a single
    vectorized Viterbi pass decodes everything at once.
    """

    def __init__(self, num_types: int) -> None:
        super().__init__()
        self.num_types = num_types
        # Per-type score tables: transitions are (type, from_tag, to_tag).
        self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
        self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
        self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
        # Placeholder — will be overwritten by loaded weights if present
        self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Small uniform init, then make BIO-invalid moves ~impossible."""
        for table in (self.transitions, self.start_transitions, self.end_transitions):
            nn.init.uniform_(table, -0.1, 0.1)
        with torch.no_grad():
            # O -> I transitions and sequences starting at I are illegal.
            self.transitions.data[:, 0, 2] = -10000.0
            self.start_transitions.data[:, 2] = -10000.0

    def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
        """Add the stored per-tag bias to the emissions, if one is set."""
        bias = self.emission_bias
        return emissions if bias is None else emissions + bias

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Batched Viterbi decoding over every (example, type) chain.

        Args:
            emissions: (batch, seq, num_types, 3)
            mask: (batch, seq) boolean
        Returns: (batch, seq, num_types) best tag sequences
        """
        # Bias the raw emissions before scoring.
        emissions = self._apply_emission_bias(emissions)
        batch, seq, num_types, _ = emissions.shape
        lanes = batch * num_types  # one "lane" per (example, type) chain

        # Fold the type axis into the batch: (lanes, seq, 3) / (lanes, seq).
        flat_em = emissions.permute(0, 2, 1, 3).reshape(lanes, seq, 3)
        flat_mask = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(lanes, seq)

        # Tile per-type parameters so each lane sees its own tables.
        trans = (
            self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(lanes, 3, 3)
        )
        start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(lanes, 3)
        end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(lanes, 3)

        lane_idx = torch.arange(lanes, device=flat_em.device)

        # Forward pass: best running score per current tag, plus backpointers.
        running = start + flat_em[:, 0]
        backptrs: list[torch.Tensor] = []
        for step in range(1, seq):
            # (lanes, prev_tag, cur_tag) candidate scores.
            candidates = running.unsqueeze(2) + trans + flat_em[:, step].unsqueeze(1)
            step_best, step_prev = candidates.max(dim=1)
            # Padded steps leave the running score frozen at the last real one.
            running = torch.where(flat_mask[:, step].unsqueeze(1), step_best, running)
            backptrs.append(step_prev)
        running = running + end

        # Backward pass: seed each lane's last valid position with the best
        # final tag, then follow backpointers; padded positions stay 0.
        best_last = running.max(dim=1).indices
        paths = torch.zeros(lanes, seq, dtype=torch.long, device=flat_em.device)
        lengths = flat_mask.sum(dim=1).long()
        paths[lane_idx, lengths - 1] = best_last
        for step in range(seq - 2, -1, -1):
            follow = backptrs[step][lane_idx, paths[:, step + 1]]
            in_range = step < (lengths - 1)
            paths[:, step] = torch.where(in_range, follow, paths[:, step])

        return paths.reshape(batch, num_types, seq).permute(0, 2, 1)
class HavelockTokenConfig(PretrainedConfig):
    """Wrapper config: the backbone's own fields plus this head's knobs.

    Extra fields:
        num_types: number of independent marker types (each decoded as BIO).
        use_crf: whether a per-type CRF head is attached for decoding.
    """

    model_type = "havelock_token_classifier"

    def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
        # Everything left in ``kwargs`` belongs to the wrapped backbone config.
        super().__init__(**kwargs)
        self.num_types, self.use_crf = num_types, use_crf
class HavelockTokenClassifier(PreTrainedModel):
    """Multi-label BIO token classifier over an arbitrary HF encoder.

    A single linear head maps hidden states to ``num_types * 3`` logits per
    token, viewed as (batch, seq, num_types, 3); decoding is either a plain
    argmax or, when ``use_crf`` is set, Viterbi via :class:`MultiLabelCRF`.
    """

    config_class = HavelockTokenConfig

    def __init__(
        self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None
    ):
        """Wire the head around an injected or config-built backbone.

        Args:
            config: wrapper config carrying the backbone fields plus
                ``num_types`` / ``use_crf``.
            backbone: pre-instantiated encoder (used by ``from_backbone``);
                when ``None`` one is constructed from ``config``.
        """
        super().__init__(config)
        self.num_types = config.num_types
        self.use_crf = config.use_crf
        # Accept injected backbone (from_pretrained path) or build from config
        # NOTE(review): ``AutoModel.from_config`` receives *this* wrapper
        # config, whose class-level model_type is "havelock_token_classifier"
        # — confirm AutoModel can resolve it when no backbone is injected.
        if backbone is not None:
            self.bert = backbone
        else:
            self.bert = AutoModel.from_config(config)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        # One flat head: 3 tags (O/B/I) per marker type.
        self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
        if self.use_crf:
            self.crf = MultiLabelCRF(config.num_types)
        # HF weight-init/tying hooks; must run after all submodules exist.
        self.post_init()

    @classmethod
    def from_backbone(
        cls,
        model_name: str,
        num_types: int,
        use_crf: bool = False,
        obi_bias: torch.Tensor | None = None,
    ) -> "HavelockTokenClassifier":
        """Build from a pretrained backbone name — the training entrypoint.

        Args:
            model_name: hub id / path for ``AutoModel.from_pretrained``.
            num_types: number of independent marker types.
            use_crf: attach a ``MultiLabelCRF`` decoding head.
            obi_bias: optional per-tag (O/B/I) emission bias, reshaped to
                (1, 1, 1, 3) and stored on the CRF; ignored unless ``use_crf``.
        """
        backbone = AutoModel.from_pretrained(model_name)
        backbone_config = backbone.config
        # Fold the backbone's fields into our wrapper config.
        config = HavelockTokenConfig(
            num_types=num_types,
            use_crf=use_crf,
            **backbone_config.to_dict(),
        )
        model = cls(config, backbone=backbone)
        if use_crf and obi_bias is not None:
            # Assigning to the registered-buffer name replaces the zero
            # placeholder set up in MultiLabelCRF.__init__.
            model.crf.emission_bias = obi_bias.reshape(1, 1, 1, 3)
        return model

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return emission logits of shape (batch, seq, num_types, 3).

        NOTE(review): extra ``kwargs`` (e.g. token_type_ids) are accepted
        but silently dropped — verify no caller expects them to reach the
        backbone.
        """
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        hidden = self.dropout(hidden)
        logits = self.classifier(hidden)
        batch, seq, _ = logits.shape
        # Split the flat head output into a 3-way tag group per type.
        return logits.view(batch, seq, self.num_types, 3)

    def decode(self, input_ids, attention_mask=None):
        """Predict tag ids, shape (batch, seq, num_types).

        Viterbi decoding when the CRF head is enabled; otherwise per-token
        argmax over the 3 BIO tags.
        """
        logits = self.forward(input_ids, attention_mask)
        if self.use_crf:
            # Default to an all-ones mask (no padding) when none is given.
            mask = (
                attention_mask.bool()
                if attention_mask is not None
                else torch.ones(
                    logits.shape[:2], dtype=torch.bool, device=logits.device
                )
            )
            return self.crf.decode(logits, mask)
        return logits.argmax(dim=-1)