Commit: Upload folder using huggingface_hub
Browse files — modeling_havelock.py: +10 −0
modeling_havelock.py
CHANGED
|
@@ -14,6 +14,8 @@ class MultiLabelCRF(nn.Module):
|
|
| 14 |
self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
|
| 15 |
self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
|
| 16 |
self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
|
|
|
|
|
|
|
| 17 |
self._reset_parameters()
|
| 18 |
|
| 19 |
def _reset_parameters(self) -> None:
|
|
@@ -24,6 +26,11 @@ class MultiLabelCRF(nn.Module):
|
|
| 24 |
self.transitions.data[:, 0, 2] = -10000.0
|
| 25 |
self.start_transitions.data[:, 2] = -10000.0
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 28 |
"""Viterbi decoding.
|
| 29 |
|
|
@@ -33,6 +40,9 @@ class MultiLabelCRF(nn.Module):
|
|
| 33 |
|
| 34 |
Returns: (batch, seq, num_types) best tag sequences
|
| 35 |
"""
|
|
|
|
|
|
|
|
|
|
| 36 |
batch, seq, num_types, _ = emissions.shape
|
| 37 |
|
| 38 |
# Reshape to (batch*num_types, seq, 3)
|
|
|
|
| 14 |
self.transitions = nn.Parameter(torch.empty(num_types, 3, 3))
|
| 15 |
self.start_transitions = nn.Parameter(torch.empty(num_types, 3))
|
| 16 |
self.end_transitions = nn.Parameter(torch.empty(num_types, 3))
|
| 17 |
+
# Placeholder — will be overwritten by loaded weights if present
|
| 18 |
+
self.register_buffer("emission_bias", None)
|
| 19 |
self._reset_parameters()
|
| 20 |
|
| 21 |
def _reset_parameters(self) -> None:
|
|
|
|
| 26 |
self.transitions.data[:, 0, 2] = -10000.0
|
| 27 |
self.start_transitions.data[:, 2] = -10000.0
|
| 28 |
|
| 29 |
+
def _apply_emission_bias(self, emissions: torch.Tensor) -> torch.Tensor:
|
| 30 |
+
if self.emission_bias is not None:
|
| 31 |
+
return emissions + self.emission_bias
|
| 32 |
+
return emissions
|
| 33 |
+
|
| 34 |
def decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 35 |
"""Viterbi decoding.
|
| 36 |
|
|
|
|
| 40 |
|
| 41 |
Returns: (batch, seq, num_types) best tag sequences
|
| 42 |
"""
|
| 43 |
+
# Apply emission bias before decoding
|
| 44 |
+
emissions = self._apply_emission_bias(emissions)
|
| 45 |
+
|
| 46 |
batch, seq, num_types, _ = emissions.shape
|
| 47 |
|
| 48 |
# Reshape to (batch*num_types, seq, 3)
|