Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.runner import BaseModule | |
| from mmocr.models.builder import DECODERS | |
# NOTE(review): ``DECODERS`` is imported above, but this class is not decorated
# with ``@DECODERS.register_module()``; if FCDecoder is meant to be built from
# a config by type name, it needs to be registered — confirm.
class FCDecoder(BaseModule):
    """FC Decoder class for NER.

    Applies dropout followed by a single linear classifier to the first
    element of the encoder outputs, producing per-position label logits and
    the argmax label index at each position.

    Args:
        num_labels (int): Number of categories mapped by entity label. Must
            be given explicitly; the ``None`` default is rejected.
        hidden_dropout_prob (float): The dropout probability of hidden layer.
        hidden_size (int): Hidden layer output layer channels (input size of
            the linear classifier).
        init_cfg (dict or list[dict], optional): Initialization config passed
            to ``BaseModule``. When omitted, a fresh list requesting Xavier
            init for ``Conv2d`` layers and Uniform init for ``BatchNorm2d``
            layers is used. An explicit ``None`` is passed through as-is.
    """

    # Sentinel marking "init_cfg not supplied". It lets each instance get its
    # own freshly built default list instead of one mutable default shared
    # across all calls, while an explicit ``init_cfg=None`` keeps its meaning.
    _INIT_CFG_UNSET = object()

    def __init__(self,
                 num_labels=None,
                 hidden_dropout_prob=0.1,
                 hidden_size=768,
                 init_cfg=_INIT_CFG_UNSET):
        if num_labels is None:
            # nn.Linear would fail on None anyway, but with an opaque message;
            # TypeError keeps the same exception class callers would have seen.
            raise TypeError(
                'FCDecoder requires num_labels (the number of entity label '
                'categories), got None')
        if init_cfg is FCDecoder._INIT_CFG_UNSET:
            # NOTE(review): these entries target Conv2d/BatchNorm2d layers,
            # but this module only builds Dropout and Linear; presumably the
            # classifier therefore keeps PyTorch's default init — confirm.
            init_cfg = [
                dict(type='Xavier', layer='Conv2d'),
                dict(type='Uniform', layer='BatchNorm2d')
            ]
        super().__init__(init_cfg=init_cfg)
        self.num_labels = num_labels
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, outputs):
        """Classify every position of the encoder's sequence output.

        Args:
            outputs (Sequence[Tensor]): Encoder outputs; only ``outputs[0]``
                is used. NOTE(review): presumably a
                ``(batch, seq_len, hidden_size)`` tensor — confirm against
                the encoder feeding this decoder.

        Returns:
            tuple: ``(logits, preds)``. ``logits`` is the classifier output,
            shaped like ``outputs[0]`` with the last dimension replaced by
            ``num_labels``. ``preds`` is a nested Python list holding the
            index of the most probable label at each position.
        """
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)
        # The classifier emits labels on the last dimension, so normalise and
        # take the argmax there. For 3-D input this equals the former dim=2,
        # and it also stays correct for inputs of other ranks.
        probs = F.softmax(logits, dim=-1)
        preds = np.argmax(probs.detach().cpu().numpy(), axis=-1).tolist()
        return logits, preds