Spaces:
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import * | |
# Width of each token vector produced by the encoder (768 for a BERT-base model).
HIDDEN_OUTPUT_FEATURES = 768
# Pretrained checkpoint name handed to BertConfig / BertModel.from_pretrained.
TRAINED_WEIGHTS = 'bert-base-uncased'
# Two output classes: "entity" and "not entity".
NUM_CLASSES = 2
class EntityBertNet(nn.Module):
    """BERT encoder + max pooling over entity token positions + linear classifier.

    The model returns raw logits with ``NUM_CLASSES`` columns; per the original
    design, softmax is left to the loss function.
    """

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Size the head from the loaded config instead of a hard-coded width so
        # changing TRAINED_WEIGHTS (e.g. to a larger checkpoint) keeps the
        # encoder and classifier shapes consistent. For 'bert-base-uncased'
        # this is the same value as HIDDEN_OUTPUT_FEATURES.
        self.fc = nn.Linear(config.hidden_size, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Compute class logits for each example.

        Args:
            input_ids: token ids forwarded to ``BertModel``.
            attn_mask: attention mask forwarded to ``BertModel``.
            entity_indices: token positions to pool over; see ``pooled_output``.

        Returns:
            Logits of shape (batch, NUM_CLASSES); no softmax is applied.
        """
        # With return_dict=False the encoder returns a tuple whose first item is
        # the last hidden state. Index it rather than unpacking exactly two
        # values, so enabling extra outputs (hidden states / attentions) in the
        # config does not break the unpacking.
        sequence_output = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False
        )[0]
        # Max pooling at entity locations.
        entity_pooled_output = self.pooled_output(sequence_output, entity_indices)
        # Softmax is expected to be applied inside the loss function.
        return self.fc(entity_pooled_output)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Max-pool token vectors at the given sequence positions.

        Args:
            bert_output: hidden states, assumed (batch, seq_len, hidden).
            indices: int64 positions along dim 1 of ``bert_output``, either of
                shape (batch, k) or already expanded to (batch, k, hidden).

        Returns:
            Tensor of shape (batch, hidden): element-wise max over the k
            selected token vectors.
        """
        if indices.dim() == bert_output.dim() - 1:
            # torch.gather needs an index with the same rank as its input, so
            # broadcast each position across the trailing hidden dimension.
            indices = indices.unsqueeze(-1).expand(*indices.shape, bert_output.size(-1))
        # NOTE(review): if callers pad entity positions with a dummy index, that
        # token also participates in the max — confirm the padding scheme.
        outputs = torch.gather(input=bert_output, dim=1, index=indices)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output