Spaces:
Sleeping
Sleeping
File size: 538 Bytes
fff452e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from torchcrf import CRF
import torch.nn as nn
class CRF_Tagger(nn.Module):
    """Linear emission layer followed by a linear-chain CRF from torchcrf.

    The CRF is constructed with ``batch_first=True``, so inputs are expected
    to be batch-major -- presumably ``(batch, seq_len, input_dim)``; TODO
    confirm against callers.

    Args:
        input_dim: Size of the last dimension of ``x``; projected to tag scores.
        num_tags: Number of output tags.
    """

    def __init__(self, input_dim, num_tags):
        super().__init__()
        # Projects each position's features to per-tag emission scores.
        self.embed2tag = nn.Linear(input_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    @staticmethod
    def _bool_mask(mask, emissions):
        """Return ``mask`` as a bool tensor, or an all-True mask when it is None.

        torchcrf uses the mask as a ``torch.where`` condition, which must be a
        bool tensor; integer masks (e.g. uint8/long padding masks) are therefore
        converted here. When ``mask`` is None, an all-True mask over the first
        two dimensions of ``emissions`` is built, so every position counts as
        valid (instead of relying on torchcrf's own default mask, whose dtype
        may not be bool -- verify for the installed torchcrf version).
        """
        if mask is None:
            return emissions.new_ones(emissions.shape[:2]).bool()
        return mask.bool()

    def forward(self, x, labels, mask=None):
        """Return the negative CRF log-likelihood of ``labels`` as a training loss.

        Args:
            x: Input features; projected by ``embed2tag`` to emission scores.
            labels: Gold tag indices, aligned with the first two dimensions of
                the emissions.
            mask: Optional padding mask (nonzero/True = real token). Any dtype
                is converted to bool; None means every position is valid.

        Returns:
            Scalar tensor: the negated log-likelihood with ``reduction="mean"``
            (per torchcrf, the mean over sequences in the batch).
        """
        emissions = self.embed2tag(x)
        mask = self._bool_mask(mask, emissions)
        # The CRF returns a log-likelihood; negate it so it can be minimised.
        return -self.crf(emissions, labels, mask=mask, reduction="mean")

    def decode(self, x, mask=None):
        """Return the highest-scoring (Viterbi) tag sequence for each batch item.

        Args:
            x: Input features, as for ``forward``.
            mask: Optional padding mask; converted to bool, None = all valid.

        Returns:
            The result of ``CRF.decode`` -- per torchcrf, one list of tag
            indices per sequence, covering only its masked (valid) positions.
        """
        emissions = self.embed2tag(x)
        return self.crf.decode(emissions, self._bool_mask(mask, emissions))