from transformers import PreTrainedModel
from transformers import AutoModel
import torch
from torch.autograd import Function

from .configuration_me2bert import ME2BertConfig


class ReverseLayerF(Function):
    """Gradient reversal layer.

    Identity on the forward pass; on the backward pass the incoming
    gradient is negated and scaled by ``alpha``. Used for adversarial
    domain adaptation (NOTE(review): defined here but not referenced in
    this file — presumably used by training code elsewhere).
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        # view_as(x) returns x unchanged while keeping autograd tracking.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the gradient; alpha itself receives no gradient.
        return grad_output.neg() * ctx.alpha, None


class FFClassifier(torch.nn.Module):
    """Two-layer feed-forward classifier head.

    Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear. Returns raw
    logits (no activation on the output layer).
    """

    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(True),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, x):
        return self.model(x)


class Encoder(torch.nn.Module):
    """MLP encoder: input_dim -> hidden_dim (PReLU) -> latent_dim."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        x = self.prelu(self.fc1(x))
        return self.fc2(x)


class Decoder(torch.nn.Module):
    """MLP decoder: latent_dim -> hidden_dim (PReLU) -> output_dim."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        x = self.prelu(self.fc1(x))
        return self.fc2(x)


class AutoEncoder(torch.nn.Module):
    """Encoder + LayerNorm bottleneck + Decoder.

    ``forward`` returns ``(encoded, decoded)`` where ``encoded`` is the
    layer-normalized latent code and ``decoded`` is the reconstruction
    of the input.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.layer_norm = torch.nn.LayerNorm(latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        encoded = self.layer_norm(self.encoder(x))
        decoded = self.decoder(encoded)
        return encoded, decoded


class GatedCombination(torch.nn.Module):
    """LSTM-style gated fusion of two embeddings of equal dimension.

    A forget gate (from the frozen embedding) and an input gate (from
    the fine-tuned embedding) weight a sum of the two; an output gate
    then modulates ``tanh`` of that combination.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, frozen_output, finetuned_output):
        forget = self.sigmoid(self.forget_gate(frozen_output))
        update = self.sigmoid(self.input_gate(finetuned_output))
        combined = forget * frozen_output + update * finetuned_output
        out_gate = self.sigmoid(self.output_gate(combined))
        return out_gate * self.tanh(combined)


class ME2BertModel(PreTrainedModel):
    """Moral-foundation classifier over a pair of BERT encoders.

    Combines a trainable encoder (``self.feature``) with a frozen copy
    (``self.bert_frozen``), optionally passes the trainable pooled
    output through an autoencoder bottleneck (``has_trans``) and fuses
    it with the frozen pooled output via gating (``has_gate``), then
    classifies the fused representation (concatenated with one-hot
    domain and emotion features) into 5 moral-foundation dimensions.
    """

    config_class = ME2BertConfig
    base_model_prefix = "me2bert"

    def __init__(self, config: ME2BertConfig = None):
        if config is None:
            config = ME2BertConfig()
        super().__init__(config)

        self.n_mf_classes = 5
        self.n_domain_classes = 2
        pretrained_model_name = config.pretrained_model_name
        self.has_gate = config.has_gate
        self.has_trans = config.has_trans
        # Default emotion feature vector (all zeros); may be replaced by
        # callers with per-batch labels — TODO confirm against training code.
        self.emotion_labels = [0, 0, 0, 0, 0]

        # Trainable encoder and a frozen reference copy of the same model.
        self.feature = AutoModel.from_pretrained(pretrained_model_name)
        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)
        for param in self.bert_frozen.parameters():
            param.requires_grad = False

        self.embedding_dim = self.feature.config.hidden_size
        latent_dim = 128
        self.emotion_dim = 5

        self.gated_combination = GatedCombination(
            embedding_dim=self.embedding_dim)
        self.trans_module = AutoEncoder(self.embedding_dim, 256, latent_dim)

        # Classifier input = embedding + one-hot domain + emotion features.
        initial_dim = (self.embedding_dim + self.n_domain_classes
                       + self.emotion_dim)
        self.mf_classifier = FFClassifier(
            initial_dim, latent_dim, self.n_mf_classes, .0)
        self.domain_classifier = FFClassifier(
            self.embedding_dim, latent_dim, self.n_domain_classes)

    def gen_feature_embeddings(self, input_ids, attention_mask):
        """Run the trainable encoder; return (last_hidden_state, pooler_output)."""
        feature = self.feature(input_ids=input_ids,
                               attention_mask=attention_mask)
        return feature.last_hidden_state, feature.pooler_output

    def forward(self, input_ids, attention_mask, return_dict=False, **kwargs):
        """Classify into 5 moral-foundation dimensions.

        Returns a ``(batch, 5)`` tensor of sigmoid scores, or — when
        ``return_dict`` is true — a list of per-example dicts keyed by
        the MFT dimension labels with scores rounded to 5 decimals.
        """
        _, pooler_output = self.gen_feature_embeddings(
            input_ids, attention_mask)

        # Frozen reference embedding: no gradients flow into bert_frozen.
        with torch.no_grad():
            frozen_output = self.bert_frozen(
                input_ids=input_ids, attention_mask=attention_mask)
            frozen_output = frozen_output.pooler_output

        device = pooler_output.device

        rec_embeddings = None
        if self.has_trans:
            # Replace the pooled output with its reconstruction from the
            # autoencoder bottleneck, optionally gated with the frozen copy.
            rec_embeddings = pooler_output
            _, pooler_output = self.trans_module(rec_embeddings)
            if self.has_gate:
                gated_output = self.gated_combination(
                    frozen_output, pooler_output)
            else:
                gated_output = pooler_output
        else:
            gated_output = pooler_output

        batch_size = gated_output.shape[0]

        # Domain feature: every example is labeled domain 0 here.
        domain_labels = torch.zeros(batch_size).long().to(device)
        domain_feature = torch.nn.functional.one_hot(
            domain_labels, num_classes=self.n_domain_classes).squeeze(1)

        # Emotion feature: broadcast a fixed vector, or one-hot encode
        # per-example labels; fall back to zeros when labels are absent.
        emotion_features = None
        if self.emotion_labels is not None:
            if isinstance(self.emotion_labels, list):
                emotion_tensor = torch.tensor(
                    self.emotion_labels, dtype=torch.float32)
                emotion_features = emotion_tensor.repeat(batch_size, 1)
            else:
                emotion_features = torch.nn.functional.one_hot(
                    self.emotion_labels.long(),
                    num_classes=self.emotion_dim
                ).squeeze(1)

        if emotion_features is None:
            emotion_features = torch.zeros(
                batch_size, self.emotion_dim).to(device)
        else:
            # Truncate to the batch in case labels cover a larger batch.
            emotion_features = emotion_features[:batch_size, :]

        class_output = torch.cat(
            (gated_output.to(device),
             domain_feature.to(device),
             emotion_features.to(device)), dim=1)
        class_output = torch.sigmoid(self.mf_classifier(class_output))

        if return_dict:
            # Moral Foundations Theory dimension labels, in output order.
            mft_dimensions = ['CH', 'FC', 'LB', 'AS', 'PD']
            result_list = []
            for i in range(class_output.shape[0]):
                row_scores = [round(score.item(), 5)
                              for score in class_output[i]]
                result_list.append(dict(zip(mft_dimensions, row_scores)))
            return result_list
        return class_output