Spaces:
Runtime error
Runtime error
| import torch | |
| import flair | |
# Sentinel tag labels. CRF.__init__ looks up their indices in the tag
# dictionary and penalises transitions into START and out of STOP.
START_TAG: str = "<START>"  # marks the beginning of a tag sequence
STOP_TAG: str = "<STOP>"  # marks the end of a tag sequence
class CRF(torch.nn.Module):
    """
    Conditional Random Field Implementation according to sgrvinod and modified to not
    only look at the current word, but also on the previously seen annotation.
    """

    def __init__(self, tag_dictionary, tagset_size: int, init_from_state_dict: bool):
        """
        :param tag_dictionary: tag dictionary in order to find ID for start and stop tags
        :param tagset_size: number of tags in the tag dictionary
        :param init_from_state_dict: whether we load a pretrained model from a state dict;
            if False, transitions into START and out of STOP are set to -10000
        """
        super().__init__()
        self.tagset_size = tagset_size
        # Transitions are used in the following way: transitions[to, from].
        self.transitions = torch.nn.Parameter(torch.randn(tagset_size, tagset_size))
        # If we are not using a pretrained model and train a fresh one, we need to set
        # transitions from any tag to START-tag and from STOP-tag to any other tag to -10000.
        if not init_from_state_dict:
            start_idx = tag_dictionary.get_idx_for_item(START_TAG)
            stop_idx = tag_dictionary.get_idx_for_item(STOP_TAG)
            # Write into the leaf parameter without recording the ops in autograd;
            # this is the documented way to initialise parameters in place.
            with torch.no_grad():
                self.transitions[start_idx, :] = -10000  # row = "to START"
                self.transitions[:, stop_idx] = -10000  # column = "from STOP"
        self.to(flair.device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation of Conditional Random Field.

        :param features: emission scores in shape (batch size, seq len, tagset size);
            the last dimension must equal ``self.tagset_size`` for the expand below
        :return: CRF scores (emission score of the "to" tag + transition score from the
            previous "from" tag) in shape (batch size, seq len, tagset size, tagset size),
            indexed as [batch, position, to, from]
        """
        batch_size, seq_len = features.size()[:2]
        # (B, S, T) -> (B, S, T, 1) -> (B, S, T, T): the emission score of each "to"
        # tag is repeated for every possible "from" tag.
        emission_scores = features.unsqueeze(-1).expand(
            batch_size, seq_len, self.tagset_size, self.tagset_size
        )
        # transitions is (T, T) indexed [to, from]; add leading dims so it broadcasts
        # over batch and sequence position.
        crf_scores = emission_scores + self.transitions.unsqueeze(0).unsqueeze(0)
        return crf_scores