permutans committed on
Commit
e7adbd7
·
verified ·
1 Parent(s): b2d93db

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_havelock.py +103 -1
modeling_havelock.py CHANGED
@@ -1,9 +1,78 @@
1
  """Custom multi-label token classifier for HuggingFace Hub."""
2
 
 
3
  import torch.nn as nn
4
  from transformers import BertModel, BertPreTrainedModel
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class HavelockTokenClassifier(BertPreTrainedModel):
8
  """Multi-label BIO token classifier with independent O/B/I heads per marker type.
9
 
@@ -16,9 +85,14 @@ class HavelockTokenClassifier(BertPreTrainedModel):
16
  def __init__(self, config):
17
  super().__init__(config)
18
  self.num_types = config.num_types
 
19
  self.bert = BertModel(config, add_pooling_layer=False)
20
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
21
  self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
 
 
 
 
22
  self.post_init()
23
 
24
  def forward(self, input_ids, attention_mask=None, **kwargs):
@@ -28,4 +102,32 @@ class HavelockTokenClassifier(BertPreTrainedModel):
28
  hidden = self.dropout(hidden)
29
  logits = self.classifier(hidden)
30
  batch, seq, _ = logits.shape
31
- return logits.view(batch, seq, self.num_types, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Custom multi-label token classifier for HuggingFace Hub."""
2
 
3
+ import torch
4
  import torch.nn as nn
5
  from transformers import BertModel, BertPreTrainedModel
6
 
7
 
8
class MultiLabelCRF(nn.Module):
    """Independent CRF per marker type for multi-label BIO tagging.

    Holds one (3 x 3) transition matrix plus start/end transition vectors
    per marker type, and batches Viterbi decoding across all types at once.
    Tag indices appear to be 0=O, 1=B, 2=I (inferred from the forbidden
    O->I and start-on-I transitions below -- TODO confirm against training
    code).
    """

    def __init__(self, num_types: int) -> None:
        super().__init__()
        self.num_types = num_types
        # One independent set of CRF parameters per marker type.
        self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
        self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
        self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize transitions uniformly, then forbid invalid BIO moves."""
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        with torch.no_grad():
            # A large negative score makes these transitions effectively
            # impossible under Viterbi: tag 0 -> tag 2 (O -> I), and
            # starting a sequence on tag 2 (I).
            self.transitions.data[:, 0, 2] = -10000.0
            self.start_transitions.data[:, 2] = -10000.0

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Viterbi decoding.

        Args:
            emissions: (batch, seq, num_types, 3)
            mask: (batch, seq) boolean

        Returns: (batch, seq, num_types) best tag sequences
        """
        batch, seq, num_types, _ = emissions.shape

        # Reshape to (batch*num_types, seq, 3) so every (example, type)
        # pair decodes as an independent 3-tag chain.
        em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3)
        mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq)

        BT = batch * num_types

        # Expand params across batch; reshape order matches `em` above
        # (num_types varies fastest within each example).
        trans = (
            self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3)
        )
        start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
        end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)

        arange = torch.arange(BT, device=em.device)
        # Viterbi forward pass: score is (BT, 3), history[i-1] holds the
        # argmax predecessor for the transition i-1 -> i.
        score = start + em[:, 0]
        history: list[torch.Tensor] = []

        for i in range(1, seq):
            # (BT, prev=3, cur=3): previous score + transition + emission.
            broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1)
            best_score, best_prev = broadcast.max(dim=1)
            # Only advance the score at valid (unpadded) positions; padded
            # steps carry the last valid score forward unchanged.
            score = torch.where(mk[:, i].unsqueeze(1), best_score, score)
            history.append(best_prev)

        score = score + end
        _, best_last = score.max(dim=1)

        best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device)
        # Assumes padding is contiguous at the end, so the last valid index
        # is mask.sum() - 1 -- TODO confirm callers never pass interior gaps.
        seq_lengths = mk.sum(dim=1).long()
        best_paths[arange, seq_lengths - 1] = best_last

        # Backtrack from the last valid position; padded positions keep the
        # zero-initialized tag 0 (O) because should_update is False there.
        for i in range(seq - 2, -1, -1):
            prev_tag = history[i][arange, best_paths[:, i + 1]]
            should_update = i < (seq_lengths - 1)
            best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i])

        return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1)
74
+
75
+
76
  class HavelockTokenClassifier(BertPreTrainedModel):
77
  """Multi-label BIO token classifier with independent O/B/I heads per marker type.
78
 
 
85
  def __init__(self, config):
86
  super().__init__(config)
87
  self.num_types = config.num_types
88
+ self.use_crf = getattr(config, "use_crf", False)
89
  self.bert = BertModel(config, add_pooling_layer=False)
90
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
91
  self.classifier = nn.Linear(config.hidden_size, config.num_types * 3)
92
+
93
+ if self.use_crf:
94
+ self.crf = MultiLabelCRF(config.num_types)
95
+
96
  self.post_init()
97
 
98
  def forward(self, input_ids, attention_mask=None, **kwargs):
 
102
  hidden = self.dropout(hidden)
103
  logits = self.classifier(hidden)
104
  batch, seq, _ = logits.shape
105
+ logits = logits.view(batch, seq, self.num_types, 3)
106
+
107
+ # If CRF is available and we're not training, return decoded tags
108
+ # stacked with logits so callers can access either
109
+ if self.use_crf and not self.training:
110
+ mask = (
111
+ attention_mask.bool()
112
+ if attention_mask is not None
113
+ else torch.ones(batch, seq, dtype=torch.bool, device=logits.device)
114
+ )
115
+ # Return logits — callers use .decode() or we add a decode method
116
+ # For HF pipeline compat, return logits; users call decode separately
117
+ pass
118
+
119
+ return logits
120
+
121
+ def decode(self, input_ids, attention_mask=None):
122
+ """Run forward pass and return Viterbi-decoded tags."""
123
+ logits = self.forward(input_ids, attention_mask)
124
+ if self.use_crf:
125
+ mask = (
126
+ attention_mask.bool()
127
+ if attention_mask is not None
128
+ else torch.ones(
129
+ logits.shape[:2], dtype=torch.bool, device=logits.device
130
+ )
131
+ )
132
+ return self.crf.decode(logits, mask)
133
+ return logits.argmax(dim=-1)