import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

HIDDEN_OUTPUT_FEATURES = 768
TRAINED_WEIGHTS = 'bert-base-uncased'
NUM_CLASSES = 3  # no relation, fst hasFeature snd, snd hasFeature fst
HIDDEN_ENTITY_FEATURES = 6  # lower -> more general but less informative entity representations

class PairBertNet(nn.Module):
    def __init__(self):
        super(PairBertNet, self).__init__()
        # self.entity_fc1 = nn.Linear(HIDDEN_OUTPUT_FEATURES, HIDDEN_ENTITY_FEATURES)
        # self.entity_fc2 = nn.Linear(HIDDEN_ENTITY_FEATURES, HIDDEN_OUTPUT_FEATURES)
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES * 2, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, masked_indices, fst_indices, snd_indices):
        # embeddings = self.bert_base.get_input_embeddings()
        # input_embeddings = embeddings(input_ids)
        #
        # # get partially masked input_embeddings for entity terms
        # unmasked_entity_embeddings = input_embeddings[masked_indices[:, 0], masked_indices[:, 1]]
        # hidden_entity_repr = torch.tanh(self.entity_fc1(unmasked_entity_embeddings))
        # masked_entity_embeddings = torch.repeat_interleave(hidden_entity_repr, 128, dim=1)  # 6 * 128 = 768
        #
        # # replace input_embeddings with partially masked ones for entities
        # input_embeddings[masked_indices[:, 0], masked_indices[:, 1]] = masked_entity_embeddings

        # run BERT over the (optionally masked) inputs
        bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask, return_dict=False)

        # max pooling over the BERT outputs at each entity's token positions
        fst_pooled_output = PairBertNet.pooled_output(bert_output, fst_indices)
        snd_pooled_output = PairBertNet.pooled_output(bert_output, snd_indices)

        # concat pooled outputs from the fst and snd entities
        combined = torch.cat((fst_pooled_output, snd_pooled_output), dim=1)

        # fc layer (softmax activation done in the loss function)
        x = self.fc(combined)
        return x

    @staticmethod
    def pooled_output(bert_output, indices):
        # bert_output: (batch, seq_len, hidden); indices: (batch, n_tokens).
        # torch.gather requires index to have the same rank as its input, so
        # expand the token positions across the hidden dimension first.
        expanded = indices.unsqueeze(-1).expand(-1, -1, bert_output.size(-1))
        outputs = torch.gather(bert_output, dim=1, index=expanded)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output
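
The gather call above was the likely source of the runtime error: torch.gather raises "Index tensor must have the same number of dimensions as input tensor" when a 2-D index tensor is used against the 3-D BERT output. A minimal standalone sketch of the fixed pooling, with made-up shapes (the 768 hidden size matches bert-base-uncased):

import torch

batch, seq_len, hidden, n_tokens = 2, 8, 768, 3
bert_output = torch.randn(batch, seq_len, hidden)
indices = torch.randint(0, seq_len, (batch, n_tokens))       # token positions, dtype long

expanded = indices.unsqueeze(-1).expand(-1, -1, hidden)      # (batch, n_tokens, hidden)
gathered = torch.gather(bert_output, dim=1, index=expanded)  # rows at the token positions
pooled, _ = gathered.max(dim=1)                              # (batch, hidden)
print(pooled.shape)                                          # torch.Size([2, 768])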
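
And a hedged end-to-end smoke test for the class as defined above, assuming fst_indices and snd_indices are (batch, n_tokens) long tensors of entity token positions, and that masked_indices may be None since the partial-masking path is commented out; the sentences and positions below are invented for illustration:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
enc = tokenizer(["the camera of this phone is great",
                 "these two spans are unrelated here"],
                padding=True, return_tensors='pt')

model = PairBertNet()
model.eval()
fst_indices = torch.tensor([[2, 2], [1, 2]])  # hypothetical entity token positions
snd_indices = torch.tensor([[5, 5], [4, 5]])
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'], None, fst_indices, snd_indices)
print(logits.shape)  # torch.Size([2, 3]) -- one logit per relation class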