# File size: 1,162 Bytes
# 0b1042e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from transformers import *

HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2  # entity, not entity


class EntityBertNet(nn.Module):
    """Binary entity classifier on top of a pretrained BERT encoder.

    Runs token ids through BERT, max-pools the hidden states at the
    entity token positions, and projects the pooled vector to
    NUM_CLASSES logits (softmax is applied in the loss function, not
    here).
    """

    def __init__(self):
        super(EntityBertNet, self).__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Projects the pooled hidden vector (768 for bert-base-uncased)
        # to the two entity/not-entity logits.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Return (batch, NUM_CLASSES) logits for the entity spans.

        input_ids / attn_mask: (batch, seq_len) tensors fed to BERT
            (shapes assumed from standard BERT usage — TODO confirm
            against the caller).
        entity_indices: token positions of the entity, either
            (batch, k) or pre-expanded (batch, k, hidden); see
            pooled_output.
        """
        # BERT encoder; return_dict=False makes the call return a tuple
        # (sequence_output, pooler_output) — we keep the per-token
        # sequence output and discard the pooler output.
        bert_output, _ = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False)

        # Max pooling of the hidden states at the entity token positions.
        entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)

        # fc layer (softmax activation done in loss function)
        x = self.fc(entity_pooled_output)
        return x

    @staticmethod
    def pooled_output(bert_output, indices):
        """Max-pool bert_output (batch, seq, hidden) at token `indices`.

        torch.gather requires `index` to have the same number of
        dimensions as `input`, so a plain (batch, k) index tensor would
        raise a RuntimeError in the original code; it is expanded here
        to (batch, k, hidden) by replicating each position index across
        the hidden dimension. A pre-expanded 3-D index tensor is still
        accepted unchanged, so existing callers are unaffected.

        Returns a (batch, hidden) tensor: the element-wise max over the
        k gathered positions.
        """
        if indices.dim() == bert_output.dim() - 1:
            # (batch, k) -> (batch, k, hidden) for gather along dim=1.
            indices = indices.unsqueeze(-1).expand(
                *indices.shape, bert_output.size(-1))
        outputs = torch.gather(input=bert_output, dim=1, index=indices)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output