import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from transformers import BertPreTrainedModel, BertModel, AutoModelForMaskedLM, BertForMaskedLM
from sentence_transformers import SentenceTransformer

from .utils import PairEnum
from losses import SupConLoss

activation_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh()}


class Bert_SCCL(BertPreTrainedModel):

    def __init__(self, config, args):
        super(Bert_SCCL, self).__init__(config)
        # The backbone, contrastive head and cluster centers are attached later via
        # init_model, after the pre-trained BERT weights have been loaded.
        self.bert = None
        self.contrast_head = None
        self.cluster_centers = None

    def init_model(self, cluster_centers=None, alpha=1.0):
        self.emb_size = self.bert.config.hidden_size
        self.alpha = alpha

        # Projection head for instance-level contrastive learning.
        self.contrast_head = nn.Sequential(
            nn.Linear(self.emb_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, 128))

        # Cluster centers are learnable and initialized from the provided centers.
        initial_cluster_centers = torch.tensor(
            cluster_centers, dtype=torch.float, requires_grad=True)
        self.cluster_centers = Parameter(initial_cluster_centers)

    def forward(self, input_ids, attention_mask, task_type):

        if task_type == "evaluate":
            return self.get_mean_embeddings(input_ids, attention_mask)

        elif task_type == "explicit":
            # Each example carries three views (anchor plus two augmentations) stacked on dim 1.
            input_ids_1, input_ids_2, input_ids_3 = torch.unbind(input_ids, dim=1)
            attention_mask_1, attention_mask_2, attention_mask_3 = torch.unbind(attention_mask, dim=1)

            mean_output_1 = self.get_mean_embeddings(input_ids_1, attention_mask_1)
            mean_output_2 = self.get_mean_embeddings(input_ids_2, attention_mask_2)
            mean_output_3 = self.get_mean_embeddings(input_ids_3, attention_mask_3)

            return mean_output_1, mean_output_2, mean_output_3

    def get_mean_embeddings(self, input_ids, attention_mask):
        # Mean-pool the token embeddings, excluding padding positions via the attention mask.
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        attention_mask = attention_mask.unsqueeze(-1)
        mean_output = torch.sum(bert_output[0] * attention_mask, dim=1) / torch.sum(attention_mask, dim=1)
        return mean_output

    def get_cluster_prob(self, embeddings):
        # Soft cluster assignment with a Student's t-distribution kernel (as in DEC):
        # q_ij = (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha + 1) / 2), normalized over clusters j.
        norm_squared = torch.sum((embeddings.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

    def local_consistency(self, embd0, embd1, embd2, criterion):
        # Encourage the cluster assignments of the augmented views to agree with the anchor's.
        p0 = self.get_cluster_prob(embd0)
        p1 = self.get_cluster_prob(embd1)
        p2 = self.get_cluster_prob(embd2)

        lds1 = criterion(p1, p0)
        lds2 = criterion(p2, p0)
        return lds1 + lds2

    def contrast_logits(self, embd1, embd2=None):
        feat1 = F.normalize(self.contrast_head(embd1), dim=1)
        if embd2 is not None:
            feat2 = F.normalize(self.contrast_head(embd2), dim=1)
            return feat1, feat2
        else:
            return feat1
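
# Usage sketch (illustrative only; the actual training loop lives elsewhere in the repo).
# Assumes `config`/`args` come from the usual HuggingFace / argparse setup and that
# `centers` is a (num_clusters, hidden_size) array, e.g. from k-means on initial embeddings:
#
#   model = Bert_SCCL(config, args)
#   model.bert = BertModel.from_pretrained('bert-base-uncased')
#   model.init_model(cluster_centers=centers, alpha=1.0)
#   emb0, emb1, emb2 = model(input_ids, attention_mask, task_type="explicit")
#   feat1, feat2 = model.contrast_logits(emb1, emb2)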


class BERT_MTP_Pretrain(nn.Module):

    def __init__(self, args):
        super(BERT_MTP_Pretrain, self).__init__()
        self.num_labels = args.num_labels
        # The masked-LM head is kept so the same backbone can be trained with both the
        # classification loss (loss_ce) and the masked-language-modelling loss (mlmForward).
        self.bert = AutoModelForMaskedLM.from_pretrained(args.pretrained_bert_model)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(args.feat_dim, args.num_labels)

    def forward(self, X):
        outputs = self.bert(**X, output_hidden_states=True)

        # Use the [CLS] token of the last hidden layer as the sentence representation.
        CLSEmbedding = outputs.hidden_states[-1][:, 0]
        CLSEmbedding = self.dropout(CLSEmbedding)
        logits = self.classifier(CLSEmbedding)

        output_dir = {"logits": logits}
        # Keep the [CLS] embedding before dropout as the hidden representation.
        output_dir["hidden_states"] = outputs.hidden_states[-1][:, 0]

        return output_dir

    def mlmForward(self, X, Y=None):
        outputs = self.bert(**X, labels=Y)
        return outputs.loss

    def loss_ce(self, logits, Y):
        loss = nn.CrossEntropyLoss()
        output = loss(logits, Y)
        return output
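
# Pre-training sketch (illustrative; names are placeholders): a labeled batch contributes
# a cross-entropy term and a masked copy of the inputs contributes an MLM term. `X` and
# `X_mlm` are tokenizer dicts and `mask_labels` the MLM label tensor:
#
#   model = BERT_MTP_Pretrain(args)
#   out = model(X)
#   loss = model.loss_ce(out["logits"], y) + model.mlmForward(X_mlm, mask_labels)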


class BERT_MTP(nn.Module):

    def __init__(self, args):
        super(BERT_MTP, self).__init__()

        self.bert = AutoModelForMaskedLM.from_pretrained(args.pretrained_bert_model)
        self.dropout = nn.Dropout(0.1)

        # MLP projection head used for contrastive learning.
        self.head = nn.Sequential(
            nn.Linear(args.feat_dim, args.feat_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(args.feat_dim, args.mlp_head_feat_dim)
        )

    def forward(self, X):
        """Return L2-normalized projection features and the raw [CLS] hidden states."""
        outputs = self.bert(**X, output_hidden_states=True, output_attentions=True)
        cls_embed = outputs.hidden_states[-1][:, 0]
        features = F.normalize(self.head(cls_embed), dim=1)

        output_dir = {"features": features}
        output_dir["hidden_states"] = cls_embed

        return output_dir

    def loss_cl(self, embds, label=None, mask=None, temperature=0.07, base_temperature=0.07, device=None):
        """Compute the (supervised) contrastive loss over the projected features."""
        loss = SupConLoss()
        output = loss(embds, labels=label, mask=mask, temperature=temperature, device=device)
        return output

    def save_backbone(self, save_path):
        self.bert.save_pretrained(save_path)
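
# Fine-tuning sketch (illustrative, placeholder names): features from two augmented views
# are stacked to shape (batch, 2, feat_dim) before calling loss_cl, the shape that
# SupConLoss implementations conventionally expect:
#
#   f1 = model(X_view1)["features"]
#   f2 = model(X_view2)["features"]
#   embds = torch.stack([f1, f2], dim=1)
#   loss = model.loss_cl(embds, label=labels, temperature=0.07, device=device)
#   model.save_backbone(save_path)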


class BERT_GCD(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BERT_GCD, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        # Projection head; note that forward below returns only the mean-pooled
        # features and does not apply this head itself.
        self.mlp_head = nn.Sequential(
            nn.Linear(args.feat_dim, args.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(args.feat_dim, args.mlp_head_feat_dim)
        )
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Mean-pool the token embeddings of the last layer.
        encoded_layer_12 = outputs.hidden_states
        last_output_tokens = encoded_layer_12[-1]
        features = last_output_tokens.mean(dim=1)

        return features


class BERT_CC(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BERT_CC, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.cluster_num = args.num_labels

        # Instance-level projection head (contrastive learning in feature space).
        self.instance_projector = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

        # Cluster-level projection head (soft assignments over cluster_num clusters).
        self.cluster_projector = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.cluster_num),
            nn.Softmax(dim=1)
        )

        self.init_weights()

    def get_features(self, h_i, h_j):
        # Project two augmented views into the instance space (L2-normalized)
        # and into the cluster space (softmax probabilities).
        z_i = F.normalize(self.instance_projector(h_i), dim=1)
        z_j = F.normalize(self.instance_projector(h_j), dim=1)

        c_i = self.cluster_projector(h_i)
        c_j = self.cluster_projector(h_j)

        return z_i, z_j, c_i, c_j

    def forward_cluster(self, x):
        # Hard cluster assignment: argmax over the cluster probabilities.
        c = self.cluster_projector(x)
        c = torch.argmax(c, dim=1)
        return c

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)

        encoded_layer_12 = outputs.hidden_states
        last_output_tokens = encoded_layer_12[-1]
        features = last_output_tokens.mean(dim=1)

        return features
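
# Contrastive-clustering sketch (illustrative; the losses themselves are defined in the
# training code): forward produces mean-pooled BERT features for each of two augmented
# views, and get_features then yields instance-level embeddings (z_i, z_j) for an
# instance-wise contrastive loss and cluster-level probabilities (c_i, c_j) for a
# cluster-wise contrastive loss.
#
#   h_i = model(input_ids_a, attention_mask=mask_a)
#   h_j = model(input_ids_b, attention_mask=mask_b)
#   z_i, z_j, c_i, c_j = model.get_features(h_i, h_j)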


class BERTForDeepAligned(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BERTForDeepAligned, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states

        # Mean-pool the last layer, then apply dense -> activation -> dropout.
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)

        if feature_ext:
            return pooled_output
        else:
            if mode == 'train':
                loss = loss_fct(logits, labels)
                return loss
            else:
                return pooled_output, logits
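
# Call pattern (illustrative; names are placeholders): feature_ext=True returns only the
# pooled features (e.g., for k-means and cluster alignment); mode='train' returns the
# scalar loss computed by the provided loss_fct; otherwise features and logits are
# returned together for evaluation.
#
#   feats = model(input_ids, attention_mask=mask, feature_ext=True)
#   loss = model(input_ids, attention_mask=mask, labels=pseudo_labels,
#                mode='train', loss_fct=nn.CrossEntropyLoss())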


class BERT_USNID(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BERT_USNID, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.args = args

        # The classification head is only needed for pre-training or when the
        # self-supervised branch is enabled.
        if args.pretrain or (not args.wo_self):
            self.classifier = nn.Linear(config.hidden_size, args.num_labels)

        self.mlp_head = nn.Linear(config.hidden_size, args.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, feature_ext=False):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)

        encoded_layer_12 = outputs.hidden_states
        last_output_tokens = encoded_layer_12[-1]
        features = last_output_tokens.mean(dim=1)

        # Apply dense -> activation -> dropout, consistent with the other backbones in this file.
        features = self.dense(features)
        pooled_output = self.activation(features)
        pooled_output = self.dropout(pooled_output)

        if self.args.pretrain or (not self.args.wo_self):
            logits = self.classifier(pooled_output)

        mlp_outputs = self.mlp_head(pooled_output)

        if feature_ext:
            if self.args.pretrain or (not self.args.wo_self):
                return features, logits
            else:
                return features, mlp_outputs
        else:
            if self.args.pretrain or (not self.args.wo_self):
                return mlp_outputs, logits
            else:
                return mlp_outputs, mlp_outputs
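
# Call pattern (illustrative): with feature_ext=True the model returns the pre-activation
# pooled features (used for clustering / centroid updates) together with either the
# classifier logits or the MLP-head outputs; otherwise it returns the head outputs used by
# the USNID training objectives. `args.pretrain` / `args.wo_self` control whether the
# classifier branch exists at all.
#
#   feats, head_out = model(input_ids, attention_mask=mask, feature_ext=True)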


class BERT_USNID_UNSUP(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BERT_USNID_UNSUP, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.args = args

        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.mlp_head = nn.Linear(config.hidden_size, args.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None, weights=None,
                feature_ext=False, mode=None, loss_fct=None, aug_feats=None, use_aug=False):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)

        encoded_layer_12 = outputs.hidden_states
        last_output_tokens = encoded_layer_12[-1]
        features = last_output_tokens.mean(dim=1)

        # Apply dense -> activation -> dropout, consistent with the other backbones in this file.
        features = self.dense(features)
        pooled_output = self.activation(features)
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)
        mlp_outputs = self.mlp_head(pooled_output)

        if feature_ext:
            return features, mlp_outputs
        else:
            return mlp_outputs, logits


class BertForConstrainClustering(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BertForConstrainClustering, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

        # Clustering layer (DEC-style soft assignments), used in the refinement stage.
        self.alpha = 1.0
        self.cluster_layer = Parameter(torch.Tensor(args.num_labels, args.num_labels))
        torch.nn.init.xavier_normal_(self.cluster_layer.data)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, u_threshold=None, l_threshold=None, mode=None, semi=False):

        eps = 1e-10
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if feature_ext:
            return logits
        else:
            if mode == 'train':
                # Pairwise similarities between L2-normalized logits.
                logits_norm = F.normalize(logits, p=2, dim=1)
                sim_mat = torch.matmul(logits_norm, logits_norm.transpose(0, -1))

                # Binary label matrix: 1 for same-class pairs, 0 for different-class pairs.
                label_mat = labels.view(-1, 1) - labels.view(1, -1)
                label_mat[label_mat != 0] = -1
                label_mat[label_mat == 0] = 1
                label_mat[label_mat == -1] = 0

                if not semi:
                    # Supervised stage: the masks come directly from the label matrix.
                    pos_mask = (label_mat > u_threshold).float()
                    neg_mask = (label_mat < l_threshold).float()
                    pos_entropy = -torch.log(torch.clamp(sim_mat, eps, 1.0)) * pos_mask
                    neg_entropy = -torch.log(torch.clamp(1 - sim_mat, eps, 1.0)) * neg_mask
                    loss = (pos_entropy.mean() + neg_entropy.mean()) * 5

                    return loss

                else:
                    # Semi-supervised stage: pairs involving unlabeled samples (label -1)
                    # fall back to the similarity thresholds, while labeled pairs keep
                    # their ground-truth constraints.
                    label_mat[labels == -1, :] = -1
                    label_mat[:, labels == -1] = -1
                    pos_mask = (sim_mat > u_threshold).float()
                    neg_mask = (sim_mat < l_threshold).float()
                    pos_mask[label_mat == 1] = 1
                    neg_mask[label_mat == 0] = 1
                    pos_entropy = -torch.log(torch.clamp(sim_mat, eps, 1.0)) * pos_mask
                    neg_entropy = -torch.log(torch.clamp(1 - sim_mat, eps, 1.0)) * neg_mask
                    # The scheduled threshold gap (u_threshold - l_threshold) is added to the loss value.
                    loss = pos_entropy.mean() + neg_entropy.mean() + u_threshold - l_threshold

                    return loss

            else:
                # Student's t-distribution soft assignment over the cluster layer.
                q = 1.0 / (1.0 + torch.sum(torch.pow(logits.unsqueeze(1) - self.cluster_layer, 2), 2) / self.alpha)
                q = q.pow((self.alpha + 1.0) / 2.0)
                q = (q.t() / torch.sum(q, 1)).t()
                return logits, q
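
# Worked example of the pairwise label matrix (illustrative): for labels [0, 0, 1],
# labels.view(-1, 1) - labels.view(1, -1) is
#   [[ 0,  0, -1],
#    [ 0,  0, -1],
#    [ 1,  1,  0]]
# which the three in-place assignments map to the binary same-class matrix
#   [[1, 1, 0],
#    [1, 1, 0],
#    [0, 0, 1]].
# In evaluation mode the returned q follows the DEC soft assignment
#   q_ij = (1 + ||logit_i - mu_j||^2 / alpha)^(-(alpha + 1) / 2) / normalizer.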


class BertForDTC(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BertForDTC, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

        # Clustering layer (DEC-style soft assignments).
        self.alpha = 1.0
        self.cluster_layer = Parameter(torch.Tensor(args.num_labels, args.num_labels))
        torch.nn.init.xavier_normal_(self.cluster_layer.data)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if feature_ext:
            return pooled_output
        elif mode == 'train':
            loss = loss_fct(logits, labels)
            return loss
        else:
            # Student's t-distribution soft assignment over the cluster layer.
            q = 1.0 / (1.0 + torch.sum(torch.pow(logits.unsqueeze(1) - self.cluster_layer, 2), 2) / self.alpha)
            q = q.pow((self.alpha + 1.0) / 2.0)
            q = (q.t() / torch.sum(q, 1)).t()
            return logits, q
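
# Clustering-stage sketch (illustrative): the soft assignment q returned above is
# typically sharpened into an auxiliary target distribution in the training loop, as in
# DEC-style objectives (this file only provides q; the target is computed elsewhere):
#
#   def target_distribution(q):
#       weight = q ** 2 / q.sum(dim=0)
#       return (weight.t() / weight.sum(dim=1)).t()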


class BertForKCL_Similarity(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BertForKCL_Similarity, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        # The pairwise classifier operates on concatenated pairs of pooled features.
        self.dense = nn.Linear(config.hidden_size * 2, config.hidden_size * 4)
        self.normalization = nn.BatchNorm1d(config.hidden_size * 4)
        self.activation = activation_map[args.activation]

        self.classifier = nn.Linear(config.hidden_size * 4, args.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, loss_fct=None, mode=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states

        # Enumerate all pairs in the batch and concatenate their features.
        feat1, feat2 = PairEnum(encoded_layer_12[-1].mean(dim=1))
        feature_cat = torch.cat([feat1, feat2], 1)

        pooled_output = self.dense(feature_cat)
        pooled_output = self.normalization(pooled_output)
        pooled_output = self.activation(pooled_output)
        logits = self.classifier(pooled_output)

        if mode == 'train':
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            return loss
        else:
            return pooled_output, logits


class BertForKCL(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BertForKCL, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None, mode=None,
                simi=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if mode == 'train':
            # Pairwise KCL loss on class probabilities; `simi` holds the pairwise similarity
            # targets supplied by the caller (e.g., predictions from BertForKCL_Similarity).
            probs = F.softmax(logits, dim=1)
            prob1, prob2 = PairEnum(probs)

            loss_KCL = loss_fct(prob1, prob2, simi)
            flag = len(labels[labels != -1])

            # Add a cross-entropy term on the labeled samples (label -1 marks unlabeled data).
            if flag != 0:
                loss_ce = nn.CrossEntropyLoss()(logits[labels != -1], labels[labels != -1])
                loss = loss_ce + loss_KCL
            else:
                loss = loss_KCL

            return loss
        else:
            return pooled_output, logits
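
# Pairwise-target sketch (illustrative): PairEnum, as typically implemented, expands a
# (batch, d) tensor into all batch*batch ordered pairs, so prob1[k] and prob2[k] are the
# class distributions of the k-th pair and simi[k] marks whether that pair should be
# pulled together or pushed apart by the KCL criterion.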


class BertForMCL(BertPreTrainedModel):

    def __init__(self, config, args):
        super(BertForMCL, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None, mode=None, loss_fct=None):

        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        probs = F.softmax(logits, dim=1)

        if mode == 'train':
            flag = len(labels[labels != -1])
            prob1, prob2 = PairEnum(probs)

            # Pseudo pairwise targets from the dot product of class distributions:
            # similar (1) above 0.5, dissimilar (-1) below 0.5.
            simi = torch.matmul(probs, probs.transpose(0, -1)).view(-1)
            simi[simi > 0.5] = 1
            simi[simi < 0.5] = -1
            loss_MCL = loss_fct(prob1, prob2, simi)

            # Add a cross-entropy term on the labeled samples (label -1 marks unlabeled data).
            if flag != 0:
                loss_ce = nn.CrossEntropyLoss()(logits[labels != -1], labels[labels != -1])
                loss = loss_ce + loss_MCL
            else:
                loss = loss_MCL

            return loss

        else:
            return pooled_output, logits
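
# Worked example (illustrative): if the class distributions of two utterances are
# [0.9, 0.1] and [0.8, 0.2], their dot product is 0.9*0.8 + 0.1*0.2 = 0.74 > 0.5, so the
# pair is treated as similar (simi = 1); values landing exactly on 0.5 are left unchanged
# by the two in-place assignments above.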