# varox34's picture
# Upload 64 files
# 366b225 verified
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from transformers import BertModel
from .scalar_mix import ScalarMix
class BertEmbedding(nn.Module):
    """Token-level embeddings built from a pretrained BERT encoder.

    The encoder is loaded with ``output_hidden_states=True``.  In
    ``forward`` the last ``n_layers`` hidden states are combined by
    ``ScalarMix``, the subword vectors belonging to each token are
    averaged, and the result is projected to ``n_out`` features by a
    bias-free linear layer.

    Args:
        model: Name or path handed to ``BertModel.from_pretrained``.
        n_layers: Number of trailing hidden-state layers passed to
            ``ScalarMix``.
        n_out: Output feature size of the final projection.
        requires_grad: Whether the encoder parameters receive gradients.
            When false the encoder is also switched to eval mode on every
            forward call.
    """

    def __init__(self, model, n_layers, n_out, requires_grad=False):
        super().__init__()

        self.bert = BertModel.from_pretrained(model, output_hidden_states=True)
        self.bert = self.bert.requires_grad_(requires_grad)
        self.n_layers = n_layers
        self.n_out = n_out
        self.requires_grad = requires_grad
        self.hidden_size = self.bert.config.hidden_size

        self.scalar_mix = ScalarMix(n_layers)
        self.projection = nn.Linear(self.hidden_size, n_out, False)

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += f"n_layers={self.n_layers}, n_out={self.n_out}"
        # Only mention requires_grad when it differs from the default.
        if self.requires_grad:
            s += f", requires_grad={self.requires_grad}"
        s += ')'
        return s

    def forward(self, subwords, bert_lens, bert_mask):
        """Return one projected embedding per token position.

        Args:
            subwords: Subword ids fed to the encoder.
            bert_lens: 2-D tensor ``(batch_size, seq_len)`` holding the
                number of subwords of each token; ``0`` marks positions
                without a token, which come out as the projection of an
                all-zero vector.
            bert_mask: Mask passed to the encoder as ``attention_mask`` and
                also used to index the mixed hidden states, so it is
                assumed to be a boolean tensor aligned with ``subwords``
                (TODO confirm against the caller).

        Returns:
            Tensor of shape ``(batch_size, seq_len, n_out)``.
        """
        batch_size, seq_len = bert_lens.shape
        mask = bert_lens.gt(0)

        if not self.requires_grad:
            # Keep dropout etc. disabled while the encoder is frozen.
            self.bert.eval()

        # Run the encoder exactly once per call.
        outputs = self.bert(subwords, attention_mask=bert_mask)
        # Dict-style outputs expose the per-layer states by attribute;
        # plain tuple outputs carry them as the last element.
        if hasattr(outputs, "hidden_states"):
            hidden_states = outputs.hidden_states
        else:
            hidden_states = outputs[-1]

        bert = self.scalar_mix(hidden_states[-self.n_layers:])
        # Select the real subword vectors and group them token by token.
        pieces = bert[bert_mask].split(bert_lens[mask].tolist())
        bert = torch.stack([piece.mean(0) for piece in pieces])
        # Scatter the token vectors back into a dense, zero-padded layout.
        bert_embed = bert.new_zeros(batch_size, seq_len, self.hidden_size)
        bert_embed = bert_embed.masked_scatter_(mask.unsqueeze(-1), bert)
        return self.projection(bert_embed)