File size: 538 Bytes
fff452e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torchcrf import CRF
import torch.nn as nn

class CRF_Tagger(nn.Module):
    """Linear emission layer followed by a linear-chain CRF for sequence tagging.

    The CRF is constructed with ``batch_first=True``, so all inputs are
    batch-first. ``x`` is presumably ``(batch, seq_len, input_dim)`` — TODO
    confirm against callers; only the last dimension is fixed by ``nn.Linear``.

    Args:
        input_dim: Size of the last dimension of ``x``.
        num_tags: Number of distinct tags the CRF scores.
    """

    def __init__(self, input_dim, num_tags):
        super().__init__()
        # Maps each position's feature vector to one score per tag (the CRF's
        # emission scores).
        self.embed2tag = nn.Linear(input_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, x, labels, mask=None, *, reduction="mean"):
        """Return the CRF negative log-likelihood of ``labels`` given ``x``.

        Args:
            x: Batch-first input features.
            labels: Gold tag indices, batch-first and aligned with ``x``.
            mask: Optional padding mask. ``None`` lets the CRF treat every
                position as valid, matching ``decode``'s default.
                NOTE(review): torchcrf requires the first timestep of the mask
                to be on for every sequence.
            reduction: Forwarded to the CRF (``"none"``, ``"sum"``, ``"mean"``
                or ``"token_mean"``). Defaults to ``"mean"``, the value that was
                previously hard-coded.

        Returns:
            The negated CRF log-likelihood, usable directly as a loss.
        """
        emissions = self.embed2tag(x)
        # torchcrf's CRF returns a log-likelihood (higher is better); negate it
        # so that minimising the result maximises the likelihood.
        return -self.crf(emissions, labels, mask=mask, reduction=reduction)

    def decode(self, x, mask=None):
        """Return the highest-scoring tag sequence for each batch element.

        Delegates to ``torchcrf.CRF.decode``, which returns a list (one entry
        per batch element) of tag-index lists.

        Args:
            x: Batch-first input features.
            mask: Optional padding mask; ``None`` treats every position as valid.
        """
        emissions = self.embed2tag(x)
        return self.crf.decode(emissions, mask=mask)