# NOTE(review): the three lines below are non-Python residue from a web
# scrape (contributor name / page label / commit hash); commented out so
# the module parses.
# GopalGoyal's picture
# start
# 0b1042e
import torch
import torch.nn as nn
from transformers import *
# Hidden size of BERT-base's per-token output; must match the checkpoint below.
HIDDEN_OUTPUT_FEATURES = 768
# Hugging Face model identifier used for both the config and the weights.
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 2 # entity, not entity
class EntityBertNet(nn.Module):
    """Binary entity classifier on top of BERT.

    Pipeline: encode tokens with BERT, max-pool the hidden states gathered
    at the entity positions, then project to NUM_CLASSES logits. Softmax is
    deliberately left to the loss function.
    """

    def __init__(self):
        super().__init__()
        cfg = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=cfg)
        # Projects the pooled (hidden-size) vector to per-class logits.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Return logits of shape (batch, NUM_CLASSES) for the entity spans.

        NOTE(review): `entity_indices` looks like it must be pre-expanded to
        (batch, k, hidden) so it is a valid `torch.gather` index over the
        sequence dimension — confirm against the caller.
        """
        # Per-token contextual embeddings: (batch, seq_len, hidden).
        hidden_states, _ = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False)
        # Max-pool the embeddings at the entity token positions.
        pooled = EntityBertNet.pooled_output(hidden_states, entity_indices)
        # Linear projection to logits (softmax is applied in the loss).
        return self.fc(pooled)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Gather rows of `bert_output` along dim 1 as selected by `indices`,
        then take the element-wise max over the gathered positions."""
        gathered = torch.gather(input=bert_output, dim=1, index=indices)
        pooled, _ = torch.max(gathered, dim=1)
        return pooled