permutans commited on
Commit
91ec653
·
verified ·
1 Parent(s): d64c032

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_havelock.py +10 -0
modeling_havelock.py CHANGED
@@ -14,6 +14,8 @@ class MultiLabelCRF(nn.Module):
14
  self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
15
  self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
16
  self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
 
 
17
  self._reset_parameters()
18
 
19
  def _reset_parameters(self) -> None:
@@ -24,6 +26,11 @@ class MultiLabelCRF(nn.Module):
24
  self.transitions.data[:, 0, 2] = -10000.0
25
  self.start_transitions.data[:, 2] = -10000.0
26
 
 
 
 
 
 
27
  def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
28
  """Viterbi decoding.
29
 
@@ -33,6 +40,9 @@ class MultiLabelCRF(nn.Module):
33
 
34
  Returns: (batch, seq, num_types) best tag sequences
35
  """
 
 
 
36
  batch, seq, num_types, _ = emissions.shape
37
 
38
  # Reshape to (batch*num_types, seq, 3)
 
14
  self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
15
  self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
16
  self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
17
+ # Placeholder — will be overwritten by loaded weights if present
18
+ self.register_buffer("emission_bias", None)
19
  self._reset_parameters()
20
 
21
  def _reset_parameters(self) -> None:
 
26
  self.transitions.data[:, 0, 2] = -10000.0
27
  self.start_transitions.data[:, 2] = -10000.0
28
 
29
+ def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
30
+ if self.emission_bias is not None:
31
+ return emissions + self.emission_bias
32
+ return emissions
33
+
34
  def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
35
  """Viterbi decoding.
36
 
 
40
 
41
  Returns: (batch, seq, num_types) best tag sequences
42
  """
43
+ # Apply emission bias before decoding
44
+ emissions = self._apply_emission_bias(emissions)
45
+
46
  batch, seq, num_types, _ = emissions.shape
47
 
48
  # Reshape to (batch*num_types, seq, 3)