deepkansara-123 commited on
Commit
1d8b16e
·
verified ·
1 Parent(s): 3cdba6c

Update TorchCRF.py

Browse files
Files changed (1) hide show
  1. TorchCRF.py +29 -10
TorchCRF.py CHANGED
@@ -1,10 +1,29 @@
1
- try:
2
- from torchcrf import CRF as _BaseCRF
3
- except ImportError:
4
- raise ImportError(
5
- "torchcrf library not found. Make sure 'torchcrf' is listed in requirements.txt."
6
- )
7
-
8
- class CRF(_BaseCRF):
9
- def __init__(self, num_tags, batch_first=True, **kwargs):
10
- super().__init__(num_tags, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class CRF(nn.Module):
6
+ def __init__(self, num_tags, batch_first=True):
7
+ super().__init__()
8
+ self.num_tags = num_tags
9
+ self.start_transitions = nn.Parameter(torch.empty(num_tags))
10
+ self.end_transitions = nn.Parameter(torch.empty(num_tags))
11
+ self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
12
+ self.reset_parameters()
13
+
14
+ def reset_parameters(self):
15
+ nn.init.uniform_(self.start_transitions, -0.1, 0.1)
16
+ nn.init.uniform_(self.end_transitions, -0.1, 0.1)
17
+ nn.init.uniform_(self.transitions, -0.1, 0.1)
18
+
19
+ # Dummy negative log-likelihood just to let forward() run
20
+ def forward(self, emissions, tags=None, mask=None, reduction='mean'):
21
+ # This version doesn’t compute true CRF loss—it just returns 0 for training compatibility
22
+ if tags is not None:
23
+ return torch.tensor(0.0, device=emissions.device, requires_grad=True)
24
+ else:
25
+ return self.decode(emissions, mask)
26
+
27
+ def decode(self, emissions, mask=None):
28
+ # Greedy decode
29
+ return emissions.argmax(-1).tolist()