permutans committed on
Commit
8019718
·
verified ·
1 Parent(s): 621c79f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_havelock.py +78 -0
modeling_havelock.py CHANGED
@@ -5,6 +5,84 @@ import torch.nn as nn
5
  from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class HavelockTokenConfig(PretrainedConfig):
9
  """Config that wraps any backbone config + our custom fields."""
10
 
 
5
  from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
6
 
7
 
8
class MultiLabelCRF(nn.Module):
    """Independent CRF per marker type for multi-label BIO tagging.

    Each of the ``num_types`` marker types gets its own 3-tag CRF with
    separate transition/start/end scores, decoded independently. The
    constraints set in ``_reset_parameters`` (tag 0 -> tag 2 forbidden,
    start -> tag 2 forbidden) suggest the tag order is O=0, B=1, I=2 —
    TODO confirm against the training code.
    """

    def __init__(self, num_types: int) -> None:
        super().__init__()
        self.num_types = num_types
        # Per-type scores: transitions[t, from_tag, to_tag], plus start/end.
        self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
        self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
        self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
        # Placeholder — will be overwritten by loaded weights if present
        self.register_buffer("emission_bias", torch.zeros(1, 1, 1, 3))
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Random init, then hard-forbid invalid BIO moves (O->I, start->I)."""
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        with torch.no_grad():
            # Large negative score ≈ -inf: these moves never win in Viterbi.
            # (Index the parameter directly under no_grad; `.data` is a
            # deprecated escape hatch.)
            self.transitions[:, 0, 2] = -10000.0
            self.start_transitions[:, 2] = -10000.0

    def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
        """Add the loaded emission bias; no-op if the buffer was cleared."""
        if self.emission_bias is not None:
            return emissions + self.emission_bias
        return emissions

    def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Viterbi decoding.

        Args:
            emissions: (batch, seq, num_types, 3)
            mask: (batch, seq) boolean

        Returns:
            (batch, seq, num_types) best tag sequences. Padded positions —
            including rows whose mask is entirely False — decode to tag 0.
        """
        # Apply emission bias before decoding
        emissions = self._apply_emission_bias(emissions)
        batch, seq, num_types, _ = emissions.shape

        # Degenerate zero-length sequence: nothing to decode (the forward
        # loop below would otherwise crash indexing em[:, 0]).
        if seq == 0:
            return emissions.new_zeros((batch, 0, num_types), dtype=torch.long)

        # Decoding is discrete (argmax back-pointers); no gradients are
        # needed, so skip graph construction.
        with torch.no_grad():
            # Fold the type axis into the batch: (batch*num_types, seq, 3)
            em = emissions.permute(0, 2, 1, 3).reshape(batch * num_types, seq, 3)
            mk = mask.unsqueeze(1).expand(-1, num_types, -1).reshape(batch * num_types, seq)
            BT = batch * num_types

            # Broadcast the per-type parameters across the batch.
            trans = (
                self.transitions.unsqueeze(0).expand(batch, -1, -1, -1).reshape(BT, 3, 3)
            )
            start = self.start_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)
            end = self.end_transitions.unsqueeze(0).expand(batch, -1, -1).reshape(BT, 3)

            arange = torch.arange(BT, device=em.device)
            score = start + em[:, 0]
            history: list[torch.Tensor] = []

            # Forward pass: best score per current tag, with argmax
            # back-pointers recorded. Masked steps leave the score frozen,
            # so `score` always reflects the last valid position.
            for i in range(1, seq):
                broadcast = score.unsqueeze(2) + trans + em[:, i].unsqueeze(1)
                best_score, best_prev = broadcast.max(dim=1)
                score = torch.where(mk[:, i].unsqueeze(1), best_score, score)
                history.append(best_prev)

            score = score + end
            _, best_last = score.max(dim=1)

            best_paths = torch.zeros(BT, seq, dtype=torch.long, device=em.device)
            seq_lengths = mk.sum(dim=1).long()
            # Clamp so an all-masked row writes index 0 instead of wrapping
            # to -1 (which silently stamped a tag onto the LAST padded
            # position of an empty row).
            last_idx = (seq_lengths - 1).clamp(min=0)
            best_paths[arange, last_idx] = best_last

            # Backtrack through the recorded pointers, only inside each
            # row's valid prefix.
            for i in range(seq - 2, -1, -1):
                prev_tag = history[i][arange, best_paths[:, i + 1]]
                should_update = i < (seq_lengths - 1)
                best_paths[:, i] = torch.where(should_update, prev_tag, best_paths[:, i])

            # Force every padded position (and fully-masked rows) to tag 0.
            best_paths = best_paths * mk.long()

        return best_paths.reshape(batch, num_types, seq).permute(0, 2, 1)
84
+
85
+
86
  class HavelockTokenConfig(PretrainedConfig):
87
  """Config that wraps any backbone config + our custom fields."""
88