File size: 2,377 Bytes
0b1042e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
from transformers import *

# Pretrained checkpoint name handed to BertConfig/BertModel.from_pretrained.
TRAINED_WEIGHTS = 'bert-base-uncased'
# Width of each per-token vector produced by the encoder (768 for BERT-base).
HIDDEN_OUTPUT_FEATURES = 768
# Relation labels: no relation, fst hasFeature snd, snd hasFeature fst.
NUM_CLASSES = 3
# Size of the (currently unused) compressed entity representation;
# lower -> more general but less informative entity representations.
HIDDEN_ENTITY_FEATURES = 6


class PairBertNet(nn.Module):
    """Classify the relation between two entity mentions in one BERT-encoded sequence.

    The sequence is encoded with a pretrained ``BertModel``. The token vectors at
    each entity's positions are max-pooled, and the two pooled vectors are
    concatenated. A single linear layer then maps them to ``NUM_CLASSES`` logits.
    """

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Take the encoder width from the checkpoint config instead of hard-coding
        # 768, so swapping TRAINED_WEIGHTS keeps the classifier shape consistent.
        self.fc = nn.Linear(config.hidden_size * 2, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, masked_indices, fst_indices, snd_indices):
        """Return unnormalised class logits for each sequence in the batch.

        Args:
            input_ids: token ids passed to BERT; presumably (batch, seq_len).
            attn_mask: attention mask passed to BERT as ``attention_mask``.
            masked_indices: currently unused. It is kept so existing callers keep
                working; it fed a disabled entity-masking experiment.
            fst_indices: token positions of the first entity along the sequence
                axis, either (batch, k) or already expanded to (batch, k, hidden).
            snd_indices: same as ``fst_indices``, for the second entity.

        Returns:
            Tensor of shape (batch, NUM_CLASSES). No softmax is applied here; the
            loss function is expected to apply it.
        """
        # With return_dict=False BERT returns a tuple whose first element is the
        # per-token sequence output. Index it rather than unpacking exactly two
        # values, so extra outputs (hidden states, attentions) don't break this.
        sequence_output = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False
        )[0]

        # Max pooling over each entity's token positions.
        fst_pooled_output = PairBertNet.pooled_output(sequence_output, fst_indices)
        snd_pooled_output = PairBertNet.pooled_output(sequence_output, snd_indices)

        # Concatenate the first- and second-entity representations.
        combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)
        return self.fc(combined)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Max-pool ``bert_output`` over the sequence positions listed in ``indices``.

        Args:
            bert_output: (batch, seq_len, hidden) token representations.
            indices: integer positions along the seq_len axis. Either (batch, k),
                or (batch, k, hidden) with each position repeated along the last axis.

        Returns:
            (batch, hidden) tensor holding the element-wise maximum over the k
            selected positions.
        """
        if indices.dim() == bert_output.dim() - 1:
            # torch.gather requires the index to have the same rank as its input,
            # so broadcast each position across the hidden dimension.
            indices = indices.unsqueeze(-1).expand(*indices.shape, bert_output.size(-1))
        selected = torch.gather(bert_output, dim=1, index=indices)
        pooled, _ = torch.max(selected, dim=1)
        return pooled