|
|
import torch |
|
|
import math |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
|
from torch import nn |
|
|
from torch.nn import CrossEntropyLoss, MSELoss |
|
|
from torch.nn.parameter import Parameter |
|
|
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM, AutoConfig |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
|
from .utils import ConvexSampler |
|
|
|
|
|
# Lookup table for the pooled-feature activation, selected via ``args.activation``.
activation_map = {
    'relu': nn.ReLU(),
    'tanh': nn.Tanh(),
}
|
|
|
|
|
class BERT_DOC(BertPreTrainedModel):
    """DOC-style BERT classifier: per-class sigmoid scores over mean-pooled features."""

    def __init__(self, config, args):
        super(BERT_DOC, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None):
        """Return features (feature_ext=True), a BCE-style loss (mode='train'),
        or (features, sigmoid scores) otherwise.
        """
        encoder_out = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        last_layer = encoder_out.hidden_states[-1]

        # Mean-pool tokens, project; note dropout is applied BEFORE the activation here.
        features = self.dense(last_layer.mean(dim=1))
        features = self.dropout(features)
        features = self.activation(features)

        # Per-class independent sigmoid scores (one-vs-rest), with dropout on the logits.
        scores = torch.sigmoid(self.dropout(self.classifier(features)))

        if feature_ext:
            return features
        if mode == 'train':
            # One-hot targets so loss_fct can be a binary/elementwise criterion.
            one_hot_targets = F.one_hot(labels, num_classes=self.num_labels)
            return loss_fct(scores, one_hot_targets.float())
        return features, scores
|
|
|
|
|
class BERT(BertPreTrainedModel):
    """Plain BERT classifier: mean-pooled features -> dense -> activation -> linear head."""

    def __init__(self, config, args):
        super(BERT, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None):
        """Return features (feature_ext=True), the loss from `loss_fct`
        (mode='train'), or (features, logits) otherwise.
        """
        encoder_out = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        last_layer = encoder_out.hidden_states[-1]

        # Mean-pool over tokens, project, then activation followed by dropout.
        features = self.activation(self.dense(last_layer.mean(dim=1)))
        features = self.dropout(features)

        class_logits = self.classifier(features)

        if feature_ext:
            return features
        if mode == 'train':
            return loss_fct(class_logits, labels)
        return features, class_logits
|
|
|
|
|
class BERT_Norm(BertPreTrainedModel):
    """BERT classifier over L2-normalized features with unit-norm class weights.

    Scores are cosine similarities (normalized features x normalized weights)
    passed through a softmax.
    """

    def __init__(self, config, args):
        super(BERT_Norm, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()
        # Class prototype matrix, row-normalized at scoring time.
        self.weight = Parameter(torch.FloatTensor(args.num_labels, args.feat_dim).to(args.device))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, device=None, head=None):
        """Return normalized features (feature_ext=True), the loss from
        `loss_fct` (mode='train'), or (features, softmax scores) otherwise.
        """
        encoder_out = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        token_states = encoder_out.hidden_states[-1]

        # Mean-pool, dropout, then project onto the unit sphere.
        feats = F.normalize(self.dropout(token_states.mean(dim=1)))

        # Cosine similarity against unit-norm class weights, softmaxed.
        # NOTE(review): loss_fct receives these softmax probabilities, not raw
        # logits — callers presumably pass an NLL-style criterion; confirm.
        probs = F.softmax(F.linear(feats, F.normalize(self.weight)), dim=1)

        if feature_ext:
            return feats
        if mode == 'train':
            return loss_fct(probs, labels)
        return feats, probs
|
|
|
|
|
class BERT_K_1_way(BertPreTrainedModel):
    """BERT classifier with K known classes plus one extra "open" class.

    During training/augmentation, ConvexSampler synthesizes extra samples (and
    labels) from convex combinations of the pooled features before
    classification with a (K+1)-way head.
    """

    def __init__(self, config, args):
        super(BERT_K_1_way, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sampler = ConvexSampler(args)
        # K known classes plus one additional logit for the open class.
        self.classifier = nn.Linear(config.hidden_size, self.num_labels + 1)
        self.t = args.temp  # temperature used to scale logits before the loss
        self.init_weights()

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None):
        """Return features (feature_ext=True), the temperature-scaled loss
        (mode='train'), or (features, logits, labels) otherwise.
        """
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))

        # BUGFIX: the original used `if mode is not 'test':`, which compares
        # string IDENTITY, not equality. Identity of equal string literals is a
        # CPython interning detail (and emits SyntaxWarning on modern Python);
        # value equality is the correct check.
        if mode != 'test':
            # Augment the batch (features AND labels) with convex pseudo-samples.
            pooled_output, labels = self.sampler(pooled_output, labels, mode=mode)

        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if feature_ext:
            return pooled_output
        if mode == 'train':
            # Temperature scaling sharpens/softens the softmax inside loss_fct.
            return loss_fct(torch.div(logits, self.t), labels)
        return pooled_output, logits, labels
|
|
|
|
|
class BERT_SEG(BertPreTrainedModel):
    """BERT with a class-means (SEG-style, Gaussian-like) classification head.

    Each class owns a learned mean vector; scores come from a prior-weighted
    softmax over negative squared Euclidean distances between the pooled
    feature and the class means. Training applies a margin `alpha` to the
    ground-truth class and adds a generative pull-to-mean term weighted by
    `lambda_`.
    """

    def __init__(self, config, args):
        super(BERT_SEG, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = activation_map[args.activation]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()

        self.alpha = args.alpha      # margin multiplier applied to the target class distance
        self.lambda_ = args.lambda_  # weight of the generative (pull-to-mean) loss term
        # Learned per-class mean vectors.
        # NOTE(review): hard-coded .cuda() assumes a GPU run — confirm.
        self.means = nn.Parameter(torch.randn(self.num_labels, args.feat_dim).cuda())
        nn.init.xavier_uniform_(self.means, gain=math.sqrt(2.0))

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, device=None, p_y=None, class_emb=None, loss_fct=None):
        """Return features (feature_ext=True), the margin+generative loss
        (mode='train'), or (features, normalized distance-based logits).

        `p_y` is a per-class prior weighting; `class_emb`, when given, replaces
        self.means in the generative term only.
        """
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        # Mean-pool last-layer token states, project, activate, dropout.
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)

        if feature_ext:
            return pooled_output
        else:
            batch_size = pooled_output.shape[0]

            # Negative squared Euclidean distance to each class mean, computed
            # batch-wise as -0.5 * (||x||^2 - 2 x.m + ||m||^2).
            XY = torch.matmul(pooled_output, torch.transpose(self.means, 0, 1))
            XX = torch.sum(pooled_output ** 2, dim=1, keepdim=True)
            YY = torch.sum(torch.transpose(self.means, 0, 1)**2, dim=0, keepdim=True)
            neg_sqr_dist = - 0.5 * (XX - 2.0 * XY + YY)

            # Weight the exponentiated distances by the class prior p(y) and
            # normalize per sample into a posterior-like distribution.
            p_y = p_y.expand_as(neg_sqr_dist).to(device)
            dist_exp = torch.exp(neg_sqr_dist)
            dist_exp_py = p_y.mul(dist_exp)
            dist_exp_sum = torch.sum(dist_exp_py, dim=1, keepdim=True)
            logits = dist_exp_py / dist_exp_sum

            if mode == 'train':
                # K multiplies the target class's negative distance by
                # (1 + alpha), enlarging the margin for the ground-truth class.
                labels_reshped = labels.view(labels.size()[0], -1)
                ALPHA = torch.zeros(batch_size, self.num_labels).to(device).scatter_(1, labels_reshped, self.alpha)
                K = ALPHA + torch.ones([batch_size, self.num_labels]).to(device)

                # Margin-adjusted version of the same prior-weighted softmax.
                dist_margin = torch.mul(neg_sqr_dist, K)
                dist_margin_exp = torch.exp(dist_margin)
                dist_margin_exp_py = p_y.mul(dist_margin_exp)
                dist_exp_sum_margin = torch.sum(dist_margin_exp_py, dim=1, keepdim=True)
                likelihood = dist_margin_exp_py / dist_exp_sum_margin
                # NOTE(review): sums log-likelihood over ALL classes, not only
                # the target column — confirm this matches the intended loss.
                loss_ce = - likelihood.log().sum() / batch_size

                # Generative term: pull each feature toward its class mean
                # (or toward externally supplied class embeddings).
                means = self.means if class_emb is None else class_emb
                means_batch = torch.index_select(means, dim=0, index=labels)
                loss_gen = (torch.sum((pooled_output - means_batch)**2) / 2) * (1. / batch_size)

                loss = loss_ce + self.lambda_ * loss_gen
                return loss

            else:
                return pooled_output, logits
|
|
|
|
|
class CosNorm_Classifier(nn.Module): |
|
|
|
|
|
def __init__(self, in_dims, out_dims, scale=64, device = None): |
|
|
|
|
|
super(CosNorm_Classifier, self).__init__() |
|
|
self.in_dims = in_dims |
|
|
self.out_dims = out_dims |
|
|
self.scale = scale |
|
|
self.weight = Parameter(torch.Tensor(out_dims, in_dims).to(device)) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_parameters(self): |
|
|
stdv = 1. / math.sqrt(self.weight.size(1)) |
|
|
self.weight.data.uniform_(-stdv, stdv) |
|
|
|
|
|
def forward(self, input, *args): |
|
|
norm_x = torch.norm(input, 2, 1, keepdim=True) |
|
|
ex = (norm_x / (1 + norm_x)) * (input / norm_x) |
|
|
ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True) |
|
|
return torch.mm(self.scale * ex, ew.t()) |
|
|
|
|
|
class BERT_Disaware(BertPreTrainedModel):
    """Distance-aware BERT classifier.

    Each pooled feature is rescaled by a "reachability" factor derived from its
    distance to the nearest and second-nearest class centroids, then scored by
    a cosine-norm classifier.
    """

    def __init__(self, config, args):
        super(BERT_Disaware, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()

        self.cosnorm_classifier = CosNorm_Classifier(
            config.hidden_size, args.num_labels, args.scale, args.device)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
                feature_ext=False, mode=None, loss_fct=None, centroids=None, dist_infos=None):
        """Return features (feature_ext=True), the loss from `loss_fct`
        (mode='train'), or (features, logits) when mode='eval'.

        `centroids` is a (num_labels, hidden_size) tensor of class centers.
        """
        # BUGFIX: the original called self.bert(input_ids, token_type_ids,
        # attention_mask, ...) positionally. BertModel's signature order is
        # (input_ids, attention_mask, token_type_ids, ...), so token_type_ids
        # was consumed as the attention mask and vice versa. Pass by keyword,
        # consistent with every other model in this file.
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_hidden_states=True)
        encoded_layer_12 = outputs.hidden_states
        pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)
        x = pooled_output

        if feature_ext:
            return pooled_output

        else:
            feat_size = x.shape[1]
            batch_size = x.shape[0]

            # L2 distance of every sample to every class centroid, then sort
            # so column 0/1 index the nearest/second-nearest centroid.
            f_expand = x.unsqueeze(1).expand(-1, self.num_labels, -1)
            centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1)
            dist_cur = torch.norm(f_expand - centroids_expand, 2, 2)
            values_nn, labels_nn = torch.sort(dist_cur, 1)

            nearest_centers = centroids[labels_nn[:, 0]]
            dist_denominator = torch.norm(x - nearest_centers, 2, 1)
            second_nearest_centers = centroids[labels_nn[:, 1]]
            dist_numerator = torch.norm(x - second_nearest_centers, 2, 1)

            # Reachability = exp(d2 - d1): larger when the sample is clearly
            # closer to its best centroid than to the runner-up.
            dist_info = dist_numerator - dist_denominator
            dist_info = torch.exp(dist_info)
            scalar = dist_info

            reachability = scalar.unsqueeze(1).expand(-1, feat_size)
            x = reachability * pooled_output

            logits = self.cosnorm_classifier(x)

            if mode == 'train':
                loss = loss_fct(logits, labels)
                return loss

            elif mode == 'eval':
                # NOTE(review): any other mode falls through and returns None —
                # confirm callers only pass 'train'/'eval' here.
                return pooled_output, logits
|
|
|
|
|
class BERT_MDF_Pretrain(nn.Module):
    """Masked-LM BERT with an auxiliary [CLS] classification head (pretraining stage)."""

    def __init__(self, args):
        super(BERT_MDF_Pretrain, self).__init__()
        self.num_labels = args.num_labels
        self.bert = BertForMaskedLM.from_pretrained(args.pretrained_bert_model)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(args.feat_dim, args.num_labels)

    def forward(self, X):
        """Classify from the last layer's [CLS] embedding; also expose that embedding."""
        encoded = self.bert(**X, output_hidden_states=True)

        cls_embedding = encoded.hidden_states[-1][:, 0]
        logits = self.classifier(self.dropout(cls_embedding))

        return {
            "logits": logits,
            "hidden_states": encoded.hidden_states[-1][:, 0],
        }

    def mlmForward(self, X, Y=None):
        """Return the masked-LM loss for inputs X with MLM labels Y."""
        return self.bert(**X, labels=Y).loss

    def loss_ce(self, logits, Y):
        """Cross-entropy between classification logits and integer targets Y."""
        criterion = nn.CrossEntropyLoss()
        return criterion(logits, Y)
|
|
|
|
|
|
|
|
|
|
|
class BERT_MDF(BertPreTrainedModel):
    """BERT with a 2-way linear head over the pooled output, returning an HF-style tuple."""

    def __init__(self, config, args):
        super(BERT_MDF, self).__init__(config)
        self.num_labels = args.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(args.feat_dim, 2)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Return (logits, hidden_states, ...) mirroring the HF tuple convention.

        NOTE(review): `inputs_embeds` and `labels` are accepted but never used
        here — presumably kept for interface compatibility; confirm.
        """
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_hidden_states=True
        )

        # Index 1 of the encoder's output tuple is the pooled representation.
        pooled = self.dropout(encoder_outputs[1])
        logits = self.classifier(pooled)

        # Prepend logits to everything after the pooled output (hidden states, ...).
        return (logits,) + encoder_outputs[2:]
|
|
|
|
|
|
|
|
class BertClassificationHead(nn.Module): |
|
|
def __init__(self, config): |
|
|
super(BertClassificationHead, self).__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
self.out_proj = nn.Linear(config.hidden_size, config.num_labels-1) |
|
|
|
|
|
def forward(self, feature): |
|
|
x = self.dropout(feature) |
|
|
x = self.dense(x) |
|
|
x = torch.tanh(x) |
|
|
x = self.dropout(x) |
|
|
x = self.out_proj(x) |
|
|
return x |
|
|
|
|
|
class BertContrastiveHead(nn.Module): |
|
|
def __init__(self, config): |
|
|
super(BertContrastiveHead, self).__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
|
|
|
def forward(self, feature): |
|
|
x = self.dropout(feature) |
|
|
x = self.dense(x) |
|
|
x = torch.tanh(x) |
|
|
x = self.dropout(x) |
|
|
x = self.out_proj(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class BERT_KNNCL(nn.Module):
    """BERT classifier trained with queue-based (kNN) contrastive learning.

    MoCo-style setup: a gradient-trained query encoder (`encoder_q`), a
    momentum-updated key encoder (`encoder_k`), and a fixed-size FIFO queue of
    past key features with their labels. Training combines a classification
    loss with a contrastive loss whose positives/negatives come from the queue.
    """

    def __init__(self, args):
        super(BERT_KNNCL, self).__init__()

        # NOTE(review): 'anum_labels' looks like a typo for 'num_labels'
        # (every other model in this file reads args.num_labels) — confirm
        # against the args definition before changing.
        self.number_labels = args.anum_labels

        config = AutoConfig.from_pretrained(
            args.bert_model,
            num_labels=self.number_labels,
        )

        # Query encoder (trained by backprop) and key encoder (momentum-updated);
        # both start from the same pretrained checkpoint.
        self.encoder_q = BertModel.from_pretrained(args.bert_model, config=config)
        self.encoder_k = BertModel.from_pretrained(args.bert_model, config=config)

        # Classification head over the query encoder's pooled output
        # (emits number_labels - 1 logits; see BertClassificationHead).
        self.classifier_liner = BertClassificationHead(config)

        # Projection heads mapping pooled outputs into the contrastive space.
        self.contrastive_liner_q = BertContrastiveHead(config)
        self.contrastive_liner_k = BertContrastiveHead(config)

        self.m = 0.999             # momentum coefficient for key-side updates
        self.T = args.temperature  # temperature dividing the contrastive logits
        self.init_weights()
        self.contrastive_rate_in_training = args.contrastive_rate_in_training

        # FIFO queue holding the K most recent key features and their labels.
        self.K = args.queue_size

        self.register_buffer("label_queue", torch.randint(0, self.number_labels, [self.K]))
        self.register_buffer("feature_queue", torch.randn(self.K, config.hidden_size))
        self.feature_queue = torch.nn.functional.normalize(self.feature_queue, dim=0)

        # Write pointer into the circular queue.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.top_k = args.top_k              # positives kept per query in the contrastive loss
        self.update_num = args.positive_num  # positive samples encoded per query

    def _dequeue_and_enqueue(self, keys, label):
        """Write a batch of key features/labels into the queue at queue_ptr.

        If the batch would run past the end of the buffer it is truncated (the
        queue does not wrap within one call); the pointer advances modulo K.
        """
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)

        if ptr + batch_size > self.K:
            # Truncate instead of wrapping around the end of the buffer.
            batch_size = self.K - ptr
            keys = keys[: batch_size]
            label = label[: batch_size]

        self.feature_queue[ptr: ptr + batch_size, :] = keys
        self.label_queue[ptr: ptr + batch_size] = label

        ptr = (ptr + batch_size) % self.K

        self.queue_ptr[0] = ptr

    def select_pos_neg_sample(self, liner_q, label_q):
        """Build contrastive logits for each query against the queue.

        Returns a (batch * top_k, 1 + neg_min) tensor whose FIRST column is the
        positive similarity and remaining columns are negatives (divided by the
        temperature), or None when some query has no positive or no negative
        entry in the queue.
        """
        label_queue = self.label_queue.clone().detach()
        feature_queue = self.feature_queue.clone().detach()

        # Broadcast the queue against the batch dimension.
        batch_size = label_q.shape[0]
        tmp_label_queue = label_queue.repeat([batch_size, 1])
        tmp_feature_queue = feature_queue.unsqueeze(0)
        tmp_feature_queue = tmp_feature_queue.repeat([batch_size, 1, 1])

        # Similarity of each query with every queued feature: (batch, K).
        # (Queries are l2-normalized upstream, so this is cosine-like.)
        cos_sim = torch.einsum('nc,nkc->nk', [liner_q, tmp_feature_queue])

        # Same-label entries are positives; everything else is a negative.
        tmp_label = label_q.unsqueeze(1)
        tmp_label = tmp_label.repeat([1, self.K])

        pos_mask_index = torch.eq(tmp_label_queue, tmp_label)
        neg_mask_index = ~ pos_mask_index

        # Negatives keep their similarities; positive slots become -inf so they
        # never win a topk. NOTE(review): hard-coded .cuda() assumes a GPU run.
        feature_value = cos_sim.masked_select(neg_mask_index)
        neg_sample = torch.full_like(cos_sim, -np.inf).cuda()
        neg_sample = neg_sample.masked_scatter(neg_mask_index, feature_value)

        pos_mask_index = pos_mask_index.int()
        pos_number = pos_mask_index.sum(dim=-1)
        pos_min = pos_number.min()
        if pos_min == 0:
            return None
        # NOTE(review): topk runs over ALL similarities, not only the
        # positive-masked ones, so a highly similar negative can be selected as
        # a "positive" — confirm this is intended.
        pos_sample, _ = cos_sim.topk(pos_min, dim=-1)
        pos_sample_top_k = pos_sample[:, 0:self.top_k]
        pos_sample = pos_sample_top_k
        pos_sample = pos_sample.contiguous().view([-1, 1])

        neg_mask_index = neg_mask_index.int()
        neg_number = neg_mask_index.sum(dim=-1)
        neg_min = neg_number.min()
        if neg_min == 0:
            return None
        # Keep each query's strongest neg_min negatives and replicate the row
        # once per retained positive.
        neg_sample, _ = neg_sample.topk(neg_min, dim=-1)
        neg_topk = min(pos_min, self.top_k)
        neg_sample = neg_sample.repeat([1, neg_topk])
        neg_sample = neg_sample.view([-1, neg_min])
        # Column 0: positive; columns 1..neg_min: negatives.
        logits_con = torch.cat([pos_sample, neg_sample], dim=-1)
        logits_con /= self.T
        return logits_con

    def init_weights(self):
        """Copy the query-side projection head's weights into the key-side head.

        Only the contrastive heads are synced here; the two encoders are
        already identical via from_pretrained.
        """
        for param_q, param_k in zip(self.contrastive_liner_q.parameters(), self.contrastive_liner_k.parameters()):
            param_k.data = param_q.data

    def update_encoder_k(self):
        """Momentum update of key-side params: k <- m * k + (1 - m) * q."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.contrastive_liner_q.parameters(), self.contrastive_liner_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    def reshape_dict(self, batch):
        """Flatten every tensor in `batch` to 2-D (-1, last_dim); mutates and returns it."""
        for k, v in batch.items():
            shape = v.shape
            batch[k] = v.view([-1, shape[-1]])
        return batch

    def l2norm(self, x: torch.Tensor):
        """Row-wise L2 normalization of x (last dimension)."""
        norm = torch.pow(x, 2).sum(dim=-1, keepdim=True).sqrt()
        x = torch.div(x, norm)
        return x

    def forward_no_multi_v2(self,
                            query,
                            positive_sample=None,
                            negative_sample=None,
                            ):
        """One training step: enqueue keys, then mix CE and contrastive losses.

        `query` is a dict of tokenizer tensors including 'labels' (popped
        before encoding); `positive_sample` holds update_num augmented views
        per query. `negative_sample` is accepted but unused here.
        """
        labels = query["labels"]
        labels = labels.view(-1)

        # Key side runs without gradients: momentum-update the key encoder,
        # encode the positives, and push them (with labels) into the queue.
        with torch.no_grad():
            self.update_encoder_k()
            update_sample = self.reshape_dict(positive_sample)
            bert_output_p = self.encoder_k(**update_sample)
            update_keys = bert_output_p[1]
            update_keys = self.contrastive_liner_k(update_keys)
            update_keys = self.l2norm(update_keys)
            # Each query contributed update_num positives: repeat its label.
            tmp_labels = labels.unsqueeze(-1)
            tmp_labels = tmp_labels.repeat([1, self.update_num])
            tmp_labels = tmp_labels.view(-1)
            self._dequeue_and_enqueue(update_keys, tmp_labels)

        query.pop('labels')

        bert_output_q = self.encoder_q(**query)
        q = bert_output_q[1]
        liner_q = self.contrastive_liner_q(q)
        liner_q = self.l2norm(liner_q)
        logits_cls = self.classifier_liner(q)

        # Classification loss; the head emits number_labels - 1 logits,
        # hence the (-1, number_labels - 1) view.
        if self.number_labels == 1:
            loss_fct = MSELoss()
            loss_cls = loss_fct(logits_cls.view(-1), labels)
        else:
            loss_fct = CrossEntropyLoss()
            loss_cls = loss_fct(logits_cls.view(-1, self.number_labels - 1), labels)

        logits_con = self.select_pos_neg_sample(liner_q, labels)

        if logits_con is not None:
            # The positive sits in column 0, so the contrastive target is 0.
            # NOTE(review): hard-coded .cuda() assumes a GPU run.
            labels_con = torch.zeros(logits_con.shape[0], dtype=torch.long).cuda()
            loss_fct = CrossEntropyLoss()
            loss_con = loss_fct(logits_con, labels_con)

            loss = loss_con * self.contrastive_rate_in_training + \
                   loss_cls * (1 - self.contrastive_rate_in_training)
        else:
            loss = loss_cls

        return SequenceClassifierOutput(
            loss=loss,
        )

    def forward(self,
                query,
                mode,
                positive_sample=None,
                negative_sample=None,
                ):
        """Dispatch on mode: 'train' -> loss output, 'validation' ->
        (predicted-label list, gold-label list), 'test' -> (probs, embeddings).
        Raises ValueError for any other mode. Note: pops 'labels' from `query`.
        """
        if mode == 'train':
            return self.forward_no_multi_v2(query=query, positive_sample=positive_sample,
                                            negative_sample=negative_sample)
        elif mode == 'validation':
            labels = query['labels']
            query.pop('labels')
            seq_embed = self.encoder_q(**query)[1]

            logits_cls = self.classifier_liner(seq_embed)
            probs = torch.softmax(logits_cls, dim=1)
            return torch.argmax(probs, dim=1).tolist(), labels.cpu().numpy().tolist()
        elif mode == 'test':

            query.pop('labels')
            seq_embed = self.encoder_q(**query)[1]
            logits_cls = self.classifier_liner(seq_embed)

            probs = torch.softmax(logits_cls, dim=1)
            return probs, seq_embed
        else:
            raise ValueError("undefined mode")
|
|
|
|
|
|