| from transformers import PreTrainedModel |
| from transformers import AutoModel |
| import torch |
| from torch.autograd import Function |
| from .configuration_me2bert import ME2BertConfig |
|
|
class ReverseLayerF(Function):
    """Gradient reversal: identity on the forward pass, ``-alpha * grad`` on the backward pass.

    Call it as ``ReverseLayerF.apply(x, alpha)``. ``alpha`` is a plain scale
    factor and receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale so backward can use it.
        ctx.alpha = alpha
        # Same values as ``x``. A view is returned (the usual idiom for identity
        # autograd Functions) so the output is a distinct tensor object.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output * ctx.alpha
        # One gradient per forward input: x gets the flipped gradient, alpha gets none.
        return flipped, None
|
|
|
|
class FFClassifier(torch.nn.Module):
    """Two-layer feed-forward head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: width of the incoming features.
        hidden_dim: width of the hidden layer.
        n_classes: number of output units (raw scores; no final activation).
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        ]
        # A single Sequential keeps state-dict keys as ``model.<index>.*``,
        # so previously saved weights still load.
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the head's raw outputs for ``x`` (presumably shaped (batch, input_dim))."""
        scores = self.model(x)
        return scores
|
|
|
|
class Encoder(torch.nn.Module):
    """Project ``input_dim`` features to ``latent_dim`` through one PReLU-activated hidden layer."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim)
        # Default PReLU: a single learnable negative slope shared by all units.
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        """Return the latent code: ``fc2(prelu(fc1(x)))`` (no activation on the output)."""
        hidden = self.prelu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
class Decoder(torch.nn.Module):
    """Map a ``latent_dim`` code back to ``output_dim`` features via one PReLU-activated hidden layer."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
        # Default PReLU: a single learnable negative slope shared by all units.
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        """Return ``fc2(prelu(fc1(x)))`` (no activation on the output)."""
        activated = self.prelu(self.fc1(x))
        reconstruction = self.fc2(activated)
        return reconstruction
|
|
|
|
class AutoEncoder(torch.nn.Module):
    """Encoder -> LayerNorm -> Decoder.

    ``forward`` returns ``(encoded, decoded)``. ``encoded`` is the
    LayerNorm-normalised latent code of width ``latent_dim``. ``decoded`` is
    the decoder's output of width ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        # Normalises the latent code before it is decoded (and before it is returned).
        self.layer_norm = torch.nn.LayerNorm(latent_dim)
        # Decodes back to input_dim so the reconstruction has the input's width.
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        latent = self.layer_norm(self.encoder(x))
        return latent, self.decoder(latent)
|
|
|
|
class GatedCombination(torch.nn.Module):
    """Gated fusion of two same-width embeddings; the gates are named after LSTM gates.

    ``forward(frozen, finetuned)`` computes::

        f = sigmoid(W_f frozen)         i = sigmoid(W_i finetuned)
        c = f * frozen + i * finetuned  o = sigmoid(W_o c)
        return o * tanh(c)

    All ``W_*`` are ``embedding_dim -> embedding_dim`` Linear layers with bias.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, frozen_output, finetuned_output):
        # Per-feature weights for each source, each in (0, 1).
        keep_frozen = self.sigmoid(self.forget_gate(frozen_output))
        keep_finetuned = self.sigmoid(self.input_gate(finetuned_output))
        mixed = keep_frozen * frozen_output + keep_finetuned * finetuned_output
        # Output gate is computed from the mixture itself.
        exposure = self.sigmoid(self.output_gate(mixed))
        return exposure * self.tanh(mixed)
|
|
|
|
class ME2BertModel(PreTrainedModel):
    """Five-way sigmoid scorer on top of a trainable and a frozen pretrained encoder.

    Both encoders are loaded with ``AutoModel.from_pretrained`` from
    ``config.pretrained_model_name``. Every parameter of the frozen copy has
    ``requires_grad`` disabled.

    In ``forward``, the trainable encoder's pooler output is optionally
    reconstructed by ``trans_module`` (``config.has_trans``). Only when that is
    enabled, it is optionally fused with the frozen encoder's pooler output by
    ``gated_combination`` (``config.has_gate``). The result is concatenated with
    a one-hot domain vector and an emotion vector and then scored by
    ``mf_classifier``. The five output columns are named by ``MFT_DIMENSIONS``
    (``mf``/``MFT`` presumably = Moral Foundations Theory; confirm).

    ``domain_classifier`` is built here but never called by ``forward``.
    Presumably training code uses it; confirm before removing, because its
    weights are part of the state dict.
    """

    config_class = ME2BertConfig
    base_model_prefix = "me2bert"

    # Output column names, in column order. Used for the list-of-dicts result
    # of ``forward(..., return_dict=True)``.
    MFT_DIMENSIONS = ("CH", "FC", "LB", "AS", "PD")

    def __init__(self, config: ME2BertConfig = None):
        """Build encoders and heads from ``config`` (a default ``ME2BertConfig`` if None)."""
        if config is None:
            config = ME2BertConfig()
        super().__init__(config)

        self.n_mf_classes = 5
        self.n_domain_classes = 2
        pretrained_model_name = config.pretrained_model_name
        self.has_gate = config.has_gate
        self.has_trans = config.has_trans
        # Emotion side-input read by ``forward``:
        #   - a list is used verbatim as one feature vector shared by every sample;
        #   - a tensor is treated as class indices and one-hot encoded;
        #   - None means all-zero emotion features.
        self.emotion_labels = [0, 0, 0, 0, 0]

        # Trainable encoder plus a second, frozen copy of the same checkpoint.
        self.feature = AutoModel.from_pretrained(pretrained_model_name)
        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)
        for param in self.bert_frozen.parameters():
            param.requires_grad = False

        self.embedding_dim = self.feature.config.hidden_size
        latent_dim = 128
        self.emotion_dim = 5

        # Submodule attribute names are part of the state-dict keys. Keep them
        # (and their order) stable so saved checkpoints continue to load.
        self.gated_combination = GatedCombination(embedding_dim=self.embedding_dim)
        self.trans_module = AutoEncoder(self.embedding_dim, 256, latent_dim)

        # Classifier input = [pooled embedding | one-hot domain | emotion vector].
        initial_dim = self.embedding_dim + self.n_domain_classes + self.emotion_dim
        self.mf_classifier = FFClassifier(initial_dim, latent_dim, self.n_mf_classes, 0.0)
        self.domain_classifier = FFClassifier(
            self.embedding_dim, latent_dim, self.n_domain_classes
        )

    def gen_feature_embeddings(self, input_ids, attention_mask):
        """Run the trainable encoder and return ``(last_hidden_state, pooler_output)``."""
        feature = self.feature(input_ids=input_ids, attention_mask=attention_mask)
        return feature.last_hidden_state, feature.pooler_output

    def _frozen_pooler_output(self, input_ids, attention_mask):
        """Return the frozen encoder's pooler output, computed without autograd tracking."""
        with torch.no_grad():
            frozen = self.bert_frozen(input_ids=input_ids, attention_mask=attention_mask)
        return frozen.pooler_output

    def _emotion_features(self, batch_size, device, dtype):
        """Build the (batch_size, emotion_dim) emotion block on ``device`` with ``dtype``."""
        labels = self.emotion_labels
        if labels is None:
            return torch.zeros(batch_size, self.emotion_dim, device=device, dtype=dtype)
        if isinstance(labels, list):
            # A list is a single feature vector broadcast to every sample. It is
            # created directly on the target device: a CPU tensor here would make
            # the later torch.cat fail when the model runs on GPU.
            features = torch.tensor(labels, dtype=dtype, device=device).repeat(batch_size, 1)
        else:
            # A tensor holds class indices. squeeze(1) drops a trailing singleton
            # axis when labels are shaped (batch, 1).
            features = torch.nn.functional.one_hot(
                labels.long().to(device), num_classes=self.emotion_dim
            ).squeeze(1)
        # Keep at most one row per sample. Cast so that torch.cat does not
        # promote the classifier input away from the model's dtype.
        return features[:batch_size, :].to(device=device, dtype=dtype)

    def _to_score_dicts(self, class_output):
        """Convert a (batch, n_mf_classes) score tensor to ``[{dimension: score}, ...]``."""
        # One host transfer for the whole batch instead of one ``.item()`` per
        # element. The values are the same Python floats.
        rows = class_output.detach().cpu().tolist()
        return [
            {dim: round(score, 5) for dim, score in zip(self.MFT_DIMENSIONS, row)}
            for row in rows
        ]

    def forward(self, input_ids, attention_mask, return_dict=False, **kwargs):
        """Score a batch along the five ``MFT_DIMENSIONS``.

        Args:
            input_ids: token ids fed to the encoder(s), presumably (batch, seq_len).
            attention_mask: attention mask matching ``input_ids``.
            return_dict: if True, return a list with one ``{dimension: score}``
                dict per sample (scores rounded to 5 decimals) instead of a
                tensor. This differs from the Hugging Face ``ModelOutput``
                convention.
            **kwargs: accepted for compatibility and ignored.

        Returns:
            A (batch, n_mf_classes) tensor of sigmoid scores, or the list of
            dicts described above.
        """
        _, pooler_output = self.gen_feature_embeddings(input_ids, attention_mask)

        gated_output = pooler_output
        if self.has_trans:
            # The autoencoder returns (latent, reconstruction). The
            # reconstruction replaces the pooled embedding.
            _, gated_output = self.trans_module(pooler_output)
            if self.has_gate:
                # The frozen encoder is only needed here, so it is only run here.
                frozen_output = self._frozen_pooler_output(input_ids, attention_mask)
                gated_output = self.gated_combination(frozen_output, gated_output)

        batch_size = gated_output.shape[0]
        device, dtype = gated_output.device, gated_output.dtype

        # Every sample is assigned domain label 0.
        domain_labels = torch.zeros(batch_size, dtype=torch.long, device=device)
        domain_feature = torch.nn.functional.one_hot(
            domain_labels, num_classes=self.n_domain_classes
        ).to(dtype)

        emotion_features = self._emotion_features(batch_size, device, dtype)

        class_input = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
        class_output = torch.sigmoid(self.mf_classifier(class_input))

        if return_dict:
            return self._to_score_dicts(class_output)
        return class_output
|
|