"""Custom multi-label token classifier — backbone-agnostic."""
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
class MultiLabelCRF(nn.Module):
    """Independent CRF per marker type for multi-label BIO tagging.

    Holds one 3-state transition table per marker type and Viterbi-decodes
    each type's tag sequence independently (types share nothing but the
    batched computation).
    """

    def __init__(self, num_types: int) -> None:
        super().__init__()
        self.num_types = num_types
        # transitions[t, i, j]: score of moving from tag i to tag j for type t.
        # Tag indices appear to be 0=O, 1=B, 2=I given the forbidden moves set
        # in _reset_parameters — TODO confirm against the label vocabulary.
        self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
        self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
        self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
        # Placeholder — will be overwritten by loaded weights if present
        self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        # Small uniform init, then hard-forbid structurally invalid moves with
        # a large negative score (effectively -inf for Viterbi max).
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        with torch.no_grad():
            # Tag 2 may neither follow tag 0 nor start a sequence
            # (BIO: an I- tag cannot follow O or open the sequence).
            self.transitions.data[:, 0, 2] = -10000.0
            self.start_transitions.data[:, 2] = -10000.0

    def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
        # Bias is (1, 1, 1, 3) and broadcasts over (batch, seq, num_types);
        # being a buffer, it follows state_dict loads and device moves.
        if self.emission_bias is not None:
            return emissions + self.emission_bias
        return emissions

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Viterbi decoding.
        Args:
            emissions: (batch, seq, num_types, 3)
            mask: (batch, seq) boolean
        Returns: (batch, seq, num_types) best tag sequences
        """
        # Apply emission bias before decoding
        emissions = self._apply_emission_bias(emissions)
        batch, seq, num_types, _ = emissions.shape
        # Reshape to (batch*num_types, seq, 3)
        em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3)
        # Every type of one example shares that example's mask row.
        mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq)
        BT = batch * num_types
        # Expand params across batch
        trans = (
            self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3)
        )
        start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
        end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
        arange = torch.arange(BT, device=em.device)
        # Forward pass: score[r, j] = best score of any path ending in tag j.
        score = start + em[:, 0]
        history: list[torch.Tensor] = []
        for i in range(1, seq):
            # (BT, prev_tag, cur_tag): prior score + transition + emission.
            broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1)
            best_score, best_prev = broadcast.max(dim=1)
            # Rows whose position i is padding keep their score frozen at the
            # last valid step, so the final `score + end` uses valid values.
            score = torch.where(mk[:, i].unsqueeze(1), best_score, score)
            # history[i-1] = argmax predecessor for position i. Entries for
            # padded steps are recorded but never read during backtracking.
            history.append(best_prev)
        score = score + end
        _, best_last = score.max(dim=1)
        best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device)
        # NOTE(review): assumes padding is contiguous at the sequence end so
        # mask.sum() locates the last valid index — confirm with the collator.
        seq_lengths = mk.sum(dim=1).long()
        best_paths[arange, seq_lengths - 1] = best_last
        # Backtrack from each row's own last valid position; padded tail
        # positions keep their zero initialization.
        for i in range(seq - 2, -1, -1):
            prev_tag = history[i][arange, best_paths[:, i + 1]]
            # Only fill positions strictly before the row's last valid index,
            # which also keeps the history read (i+1) within the valid range.
            should_update = i < (seq_lengths - 1)
            best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i])
        return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1)
class HavelockTokenConfig(PretrainedConfig):
    """Configuration for the Havelock token classifier.

    Carries every field of the wrapped backbone's config (forwarded through
    ``**kwargs``) plus the two classifier-specific knobs below.
    """

    model_type = "havelock_token_classifier"

    def __init__(self, num_types: int = 1, use_crf: bool = False, **kwargs):
        # Let PretrainedConfig absorb the backbone fields first so that our
        # own attributes cannot be clobbered by the kwargs pass-through.
        super().__init__(**kwargs)
        # Number of independent marker types, each tagged with a 3-tag scheme.
        self.num_types = num_types
        # Whether a per-type CRF head sits on top of the logits.
        self.use_crf = use_crf
class HavelockTokenClassifier(PreTrainedModel):
    """Backbone-agnostic token classifier: 3-way logits per marker type,
    with an optional per-type CRF for structured decoding."""

    config_class = HavelockTokenConfig

    def __init__(
        self, config: HavelockTokenConfig, backbone: PreTrainedModel | None = None
    ):
        super().__init__(config)
        self.num_types = config.num_types
        self.use_crf = config.use_crf
        # Either reuse an injected backbone (from_pretrained path) or build a
        # fresh one from the config.
        # NOTE(review): AutoModel.from_config receives our custom config here,
        # whose model_type is "havelock_token_classifier" — verify AutoModel
        # can resolve it, or that this path is never hit in practice.
        self.bert = backbone if backbone is not None else AutoModel.from_config(config)
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        # One (O, B, I) triple per marker type, emitted as a flat projection.
        self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
        if self.use_crf:
            self.crf = MultiLabelCRF(config.num_types)
        self.post_init()

    @classmethod
    def from_backbone(
        cls,
        model_name: str,
        num_types: int,
        use_crf: bool = False,
        obi_bias: torch.Tensor | None = None,
    ) -> "HavelockTokenClassifier":
        """Build from a pretrained backbone name — the training entrypoint."""
        encoder = AutoModel.from_pretrained(model_name)
        # Merge our fields with every backbone config field.
        cfg = HavelockTokenConfig(
            num_types=num_types,
            use_crf=use_crf,
            **encoder.config.to_dict(),
        )
        instance = cls(cfg, backbone=encoder)
        if use_crf and obi_bias is not None:
            # nn.Module.__setattr__ routes this tensor into the registered
            # emission_bias buffer rather than creating a plain attribute.
            instance.crf.emission_bias = obi_bias.reshape(1, 1, 1, 3)
        return instance

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return per-token logits shaped (batch, seq, num_types, 3).

        Extra **kwargs are accepted for Trainer compatibility but ignored.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.last_hidden_state)
        flat_logits = self.classifier(features)
        n_batch, n_seq, _ = flat_logits.shape
        return flat_logits.view(n_batch, n_seq, self.num_types, 3)

    def decode(self, input_ids, attention_mask=None):
        """Best tag per token and type: CRF Viterbi when enabled, else argmax."""
        logits = self.forward(input_ids, attention_mask)
        if not self.use_crf:
            return logits.argmax(dim=-1)
        # CRF path needs a boolean mask; treat every position as valid when
        # no attention mask was supplied.
        if attention_mask is None:
            mask = torch.ones(logits.shape[:2], dtype=torch.bool, device=logits.device)
        else:
            mask = attention_mask.bool()
        return self.crf.decode(logits, mask)
|