from copy import deepcopy

import numpy as np
import torch
from torch import nn
# from pytorch_pretrained_bert.modeling import BertModel
from transformers import BertConfig, RobertaConfig, RobertaModel, BertModel
class BertEncoder(nn.Module):
    """Language backbone that wraps a HuggingFace BERT/RoBERTa model and pools its hidden states."""

    def __init__(self, cfg):
        super(BertEncoder, self).__init__()
        self.cfg = cfg
        self.bert_name = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT)

        if self.bert_name == "bert-base-uncased":
            config = BertConfig.from_pretrained(self.bert_name)
            config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
            self.model = BertModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config)
            self.language_dim = 768
        elif self.bert_name == "roberta-base":
            config = RobertaConfig.from_pretrained(self.bert_name)
            config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
            self.model = RobertaModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config)
            self.language_dim = 768
        else:
            raise NotImplementedError
        # number of top hidden layers that get averaged into the language features
        self.num_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS

    def forward(self, x):
        input_ids = x["input_ids"]
        mask = x["attention_mask"]

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            # with padding: run on the full padded sequence (always 256 tokens in this setup)
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=mask,
                output_hidden_states=True,
            )
            # hidden_states has 13 entries: the embedding output plus the 12 encoder layers
            encoded_layers = outputs.hidden_states[1:]
            # average the last num_layers hidden layers, then scale by 1/num_layers again
            features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
            # language embedding has shape [len(phrase), seq_len, language_dim]
            features = features / self.num_layers

            # zero out padded positions and mean-pool the valid tokens
            embedded = features * mask.unsqueeze(-1).float()
            aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())
        else:
            # without padding: only run on the longest non-padded length in the batch
            # (assumes the pad token id is 0)
            max_len = (input_ids != 0).sum(1).max().item()
            outputs = self.model(
                input_ids=input_ids[:, :max_len],
                attention_mask=mask[:, :max_len],
                output_hidden_states=True,
            )
            # hidden_states has 13 entries: the embedding output plus the 12 encoder layers
            encoded_layers = outputs.hidden_states[1:]
            # average the last num_layers hidden layers, then scale by 1/num_layers again
            features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
            # language embedding has shape [len(phrase), seq_len, language_dim]
            features = features / self.num_layers

            # zero out padded positions and mean-pool the valid tokens
            embedded = features * mask[:, :max_len].unsqueeze(-1).float()
            aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())

        ret = {
            "aggregate": aggregate,
            "embedded": embedded,
            "masks": mask,
            "hidden": encoded_layers[-1],
        }
        return ret
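
# Usage sketch: builds a minimal stand-in config exposing only the fields BertEncoder
# reads above and runs one forward pass. Assumption: the surrounding project supplies a
# yacs-style cfg with these keys; SimpleNamespace is used here only to keep the sketch
# self-contained, and the example phrases are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace as NS
    from transformers import AutoTokenizer

    cfg = NS(
        MODEL=NS(
            LANGUAGE_BACKBONE=NS(
                MODEL_TYPE="bert-base-uncased",
                USE_CHECKPOINT=False,
                N_LAYERS=1,
            ),
            DYHEAD=NS(FUSE_CONFIG=NS(USE_DOT_PRODUCT_TOKEN_LOSS=False)),
        )
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokens = tokenizer(
        ["a photo of a cat", "two dogs playing"],
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )

    encoder = BertEncoder(cfg)
    with torch.no_grad():
        out = encoder({"input_ids": tokens["input_ids"],
                       "attention_mask": tokens["attention_mask"]})

    # aggregate: [batch, 768]; embedded / hidden: [batch, max_len, 768]; masks: [batch, 256]
    print({k: tuple(v.shape) for k, v in out.items()})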