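"""BERT-based classification backbones and heads for a family of open-set /
open intent detection baselines (inferred from the class names): a sigmoid
(DOC-style) classifier, a plain softmax classifier, a normalized-feature
classifier, a (K+1)-way classifier with convex sampling, a prototype (SEG)
head, a distance-aware classifier with a cosine-norm head, MDF pretraining and
classification wrappers, and a KNN-contrastive learning model."""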
import torch
import math
import torch.nn.functional as F
import numpy as np
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.parameter import Parameter
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from .utils import ConvexSampler
activation_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh()}
class BERT_DOC(BertPreTrainedModel):
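    """Sigmoid-output classifier (presumably the DOC-style baseline): the last
    hidden layer is mean-pooled, passed through a dense head, and scored with
    per-class sigmoids; training compares the scores against one-hot targets
    using the criterion supplied via `loss_fct`."""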
def __init__(self, config, args):
super(BERT_DOC, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = activation_map[args.activation]
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, args.num_labels)
self.init_weights()
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
feature_ext=False, mode=None, loss_fct=None, centroids = None):
outputs = self.bert(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
pooled_output = self.dropout(pooled_output)
pooled_output = self.activation(pooled_output)
logits = self.classifier(pooled_output)
logits = self.dropout(logits)
sigmoid = nn.Sigmoid()
logits = sigmoid(logits)
if feature_ext:
return pooled_output
else:
if mode == 'train':
target = F.one_hot(labels, num_classes = self.num_labels)
loss_bce = loss_fct(logits, target.float())
return loss_bce
else:
return pooled_output, logits
class BERT(BertPreTrainedModel):
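    """Plain BERT classifier: mean-pooled last hidden states, a dense head,
    and a K-way linear classifier; the training criterion (typically
    cross-entropy) is supplied via `loss_fct`."""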
def __init__(self, config, args):
super(BERT, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = activation_map[args.activation]
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, args.num_labels)
self.init_weights()
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
feature_ext=False, mode=None, loss_fct=None, centroids = None):
outputs = self.bert(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
pooled_output = self.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if feature_ext:
return pooled_output
else:
if mode == 'train':
loss_ce = loss_fct(logits, labels)
return loss_ce
else:
return pooled_output, logits
class BERT_Norm(BertPreTrainedModel):
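    """Classifier with L2-normalized features and class weights: logits are
    cosine similarities passed through a softmax, and the training criterion
    is supplied via `loss_fct`."""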
def __init__(self, config, args):
super(BERT_Norm, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.init_weights()
self.weight = Parameter(torch.FloatTensor(args.num_labels, args.feat_dim).to(args.device))
nn.init.xavier_uniform_(self.weight)
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
feature_ext=False, mode=None, loss_fct=None, device = None, head = None):
outputs = self.bert(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = encoded_layer_12[-1].mean(dim=1)
pooled_output = self.dropout(pooled_output)
pooled_output = F.normalize(pooled_output)
logits = F.linear(pooled_output, F.normalize(self.weight))
logits = F.softmax(logits, dim = 1)
if feature_ext:
return pooled_output
else:
if mode == 'train':
loss = loss_fct(logits, labels)
return loss
else:
return pooled_output, logits
class BERT_K_1_way(BertPreTrainedModel):
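    """(K+1)-way classifier: outside of test mode a `ConvexSampler` injects
    synthetic open-class samples, and the training loss uses logits scaled by
    the temperature `args.temp`."""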
def __init__(self, config, args):
super(BERT_K_1_way, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = activation_map[args.activation]
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.sampler = ConvexSampler(args)
self.classifier = nn.Linear(config.hidden_size, self.num_labels + 1)
self.t = args.temp
self.init_weights()
def forward(self, input_ids = None, token_type_ids = None, attention_mask=None , labels = None,
feature_ext = False, mode = None, loss_fct = None):
outputs = self.bert(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
        if mode != 'test':
pooled_output, labels = self.sampler(pooled_output, labels, mode=mode)
pooled_output = self.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if feature_ext:
return pooled_output
else:
if mode == 'train':
loss = loss_fct(torch.div(logits, self.t), labels)
return loss
else:
return pooled_output, logits, labels
class BERT_SEG(BertPreTrainedModel):
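    """Prototype/Gaussian head (presumably the SEG baseline): class means are
    learned parameters, logits come from exponentiated negative squared
    distances reweighted by the class prior `p_y`, and the training loss adds
    a margin (`alpha`) cross-entropy term to a pull-to-mean generative term
    weighted by `lambda_`."""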
def __init__(self, config, args):
super(BERT_SEG, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = activation_map[args.activation]
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.init_weights()
self.alpha = args.alpha
self.lambda_ = args.lambda_
self.means = nn.Parameter(torch.randn(self.num_labels, args.feat_dim).cuda())
nn.init.xavier_uniform_(self.means, gain=math.sqrt(2.0))
def forward(self, input_ids = None, token_type_ids = None, attention_mask=None , labels = None,
feature_ext = False, mode = None, device=None, p_y = None, class_emb=None, loss_fct=None):
outputs = self.bert(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
pooled_output = self.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
if feature_ext:
return pooled_output
else:
batch_size = pooled_output.shape[0]
XY = torch.matmul(pooled_output, torch.transpose(self.means, 0, 1))
XX = torch.sum(pooled_output ** 2, dim=1, keepdim=True)
YY = torch.sum(torch.transpose(self.means, 0, 1)**2, dim=0, keepdim=True)
neg_sqr_dist = - 0.5 * (XX - 2.0 * XY + YY)
# with p_y
########################################
p_y = p_y.expand_as(neg_sqr_dist).to(device) # [bsz, n_c_seen]
dist_exp = torch.exp(neg_sqr_dist)
dist_exp_py = p_y.mul(dist_exp)
dist_exp_sum = torch.sum(dist_exp_py, dim=1, keepdim=True) # [bsz, n_c_seen] -> [bsz, 1]
            logits = dist_exp_py / dist_exp_sum  # [bsz, n_c_seen]
if mode == 'train':
                labels_reshaped = labels.view(labels.size()[0], -1)  # [bsz] -> [bsz, 1]
                ALPHA = torch.zeros(batch_size, self.num_labels).to(device).scatter_(1, labels_reshaped, self.alpha)  # margin
K = ALPHA + torch.ones([batch_size, self.num_labels]).to(device)
#######################################
dist_margin = torch.mul(neg_sqr_dist, K)
dist_margin_exp = torch.exp(dist_margin)
dist_margin_exp_py = p_y.mul(dist_margin_exp)
dist_exp_sum_margin = torch.sum(dist_margin_exp_py, dim=1, keepdim=True)
likelihood = dist_margin_exp_py / dist_exp_sum_margin
loss_ce = - likelihood.log().sum() / batch_size
#######################################
means = self.means if class_emb is None else class_emb
means_batch = torch.index_select(means, dim=0, index=labels)
loss_gen = (torch.sum((pooled_output - means_batch)**2) / 2) * (1. / batch_size)
########################################
loss = loss_ce + self.lambda_ * loss_gen
return loss
else:
return pooled_output, logits
class CosNorm_Classifier(nn.Module):
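    """Cosine-norm classifier: inputs are rescaled by ||x|| / (1 + ||x||),
    weight rows are L2-normalized, and the output is the scaled dot product."""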
def __init__(self, in_dims, out_dims, scale=64, device = None):
super(CosNorm_Classifier, self).__init__()
self.in_dims = in_dims
self.out_dims = out_dims
self.scale = scale
self.weight = Parameter(torch.Tensor(out_dims, in_dims).to(device))
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, *args):
norm_x = torch.norm(input, 2, 1, keepdim=True)
ex = (norm_x / (1 + norm_x)) * (input / norm_x)
ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True)
return torch.mm(self.scale * ex, ew.t())
class BERT_Disaware(BertPreTrainedModel):
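    """Distance-aware classifier: features are multiplied by a reachability
    factor exp(d2 - d1), where d1 and d2 are the distances to the nearest and
    second-nearest class centroids, before a cosine-norm classification."""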
def __init__(self, config, args):
super(BERT_Disaware, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.ReLU()
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.init_weights()
self.cosnorm_classifier = CosNorm_Classifier(
config.hidden_size, args.num_labels, args.scale, args.device)
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None,
feature_ext=False, mode=None, loss_fct=None, centroids=None, dist_infos = None):
        outputs = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
encoded_layer_12 = outputs.hidden_states
pooled_output = self.dense(encoded_layer_12[-1].mean(dim=1))
pooled_output = self.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
x = pooled_output
if feature_ext:
return pooled_output
else:
feat_size = x.shape[1]
batch_size = x.shape[0]
f_expand = x.unsqueeze(1).expand(-1, self.num_labels, -1)
centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1)
dist_cur = torch.norm(f_expand - centroids_expand, 2, 2)
values_nn, labels_nn = torch.sort(dist_cur, 1)
nearest_centers = centroids[labels_nn[:, 0]]
dist_denominator = torch.norm(x - nearest_centers, 2, 1)
second_nearest_centers = centroids[labels_nn[:, 1]]
dist_numerator = torch.norm(x - second_nearest_centers, 2, 1)
dist_info = dist_numerator - dist_denominator
dist_info = torch.exp(dist_info)
scalar = dist_info
reachability = scalar.unsqueeze(1).expand(-1, feat_size)
x = reachability * pooled_output
logits = self.cosnorm_classifier(x)
if mode == 'train':
loss = loss_fct(logits, labels)
return loss
elif mode == 'eval':
return pooled_output, logits
class BERT_MDF_Pretrain(nn.Module):
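    """Pretraining wrapper around BertForMaskedLM: a linear head over the [CLS]
    embedding for classification, plus helpers for the masked-LM loss
    (`mlmForward`) and cross-entropy (`loss_ce`)."""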
def __init__(self, args):
super(BERT_MDF_Pretrain, self).__init__()
self.num_labels = args.num_labels
self.bert = BertForMaskedLM.from_pretrained(args.pretrained_bert_model)
self.dropout = nn.Dropout(0.1) #0.1
self.classifier = nn.Linear(args.feat_dim, args.num_labels)
def forward(self, X):
outputs = self.bert(**X, output_hidden_states=True)
CLSEmbedding = outputs.hidden_states[-1][:,0]
CLSEmbedding = self.dropout(CLSEmbedding)
logits = self.classifier(CLSEmbedding)
output_dir = {"logits": logits}
output_dir["hidden_states"] = outputs.hidden_states[-1][:, 0]
return output_dir
def mlmForward(self, X, Y = None):
outputs = self.bert(**X, labels = Y)
return outputs.loss
def loss_ce(self, logits, Y):
loss = nn.CrossEntropyLoss()
output = loss(logits, Y)
return output
class BERT_MDF(BertPreTrainedModel):
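    """Two-way classifier over BERT's pooled output (the 2-class head suggests
    an in-domain vs. open decision)."""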
def __init__(self, config, args):
super(BERT_MDF, self).__init__(config)
self.num_labels = args.num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(0.1) #0.1
self.classifier = nn.Linear(args.feat_dim, 2)
self.init_weights()
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
output_hidden_states=True
)
        # inputs_embeds is accepted for API compatibility but intentionally not
        # forwarded: BertModel complains when both input_ids and inputs_embeds are given.
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]  # prepend logits to (hidden_states, attentions)
return outputs # (loss), logits, (hidden_states), (attentions)
class BertClassificationHead(nn.Module):
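    """Classification head: dropout -> dense -> tanh -> dropout -> projection
    to `num_labels - 1` classes."""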
def __init__(self, config):
super(BertClassificationHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels-1)
def forward(self, feature):
x = self.dropout(feature)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class BertContrastiveHead(nn.Module):
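    """Projection head for the contrastive branch: dropout -> dense -> tanh ->
    dropout -> linear map back to `hidden_size`."""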
def __init__(self, config):
super(BertContrastiveHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, feature):
x = self.dropout(feature)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class BERT_KNNCL(nn.Module):
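    """KNN-contrastive model with a MoCo-style momentum encoder: `encoder_q`
    is trained directly, `encoder_k` tracks it by exponential moving average,
    and a fixed-size label/feature queue supplies positives and negatives for
    a contrastive loss mixed with the classification loss."""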
def __init__(self, args):
super(BERT_KNNCL, self).__init__()
        self.number_labels = args.num_labels
config = AutoConfig.from_pretrained(
args.bert_model ,
num_labels=self.number_labels,
)
self.encoder_q = BertModel.from_pretrained(args.bert_model, config=config)
self.encoder_k = BertModel.from_pretrained(args.bert_model, config=config)
self.classifier_liner = BertClassificationHead(config)
self.contrastive_liner_q = BertContrastiveHead(config)
self.contrastive_liner_k = BertContrastiveHead(config)
self.m = 0.999
self.T = args.temperature
        self.init_weights()  # copy contrastive_liner_q weights into contrastive_liner_k
self.contrastive_rate_in_training = args.contrastive_rate_in_training
# create the label_queue and feature_queue
self.K = args.queue_size # 7500
self.register_buffer("label_queue", torch.randint(0, self.number_labels, [self.K])) # Tensor:(7500,)
self.register_buffer("feature_queue", torch.randn(self.K, config.hidden_size)) # Tensor:(7500, 768)
self.feature_queue = torch.nn.functional.normalize(self.feature_queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) # Tensor(1,)
self.top_k = args.top_k # 25
self.update_num = args.positive_num # 3
        # Optional: partially freeze encoder_q (commented out below); some
        # experiments indicated that removing this block improves performance.
# params_to_train = ["layer." + str(i) for i in range(0, 12)]
# for name, param in self.encoder_q.named_parameters():
# param.requires_grad_(False)
# for term in params_to_train:
# if term in name:
# param.requires_grad_(True)
def _dequeue_and_enqueue(self, keys, label):
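        """Write `keys`/`label` into the circular queue at `queue_ptr`,
        truncating the batch if it would run past the end of the queue."""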
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
if ptr + batch_size > self.K:
batch_size = self.K - ptr
keys = keys[: batch_size]
label = label[: batch_size]
        # replace the keys at ptr (dequeue and enqueue)
self.feature_queue[ptr: ptr + batch_size, :] = keys
self.label_queue[ptr: ptr + batch_size] = label
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def select_pos_neg_sample(self, liner_q, label_q):
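        """Build temperature-scaled contrastive logits from the queue: cosine
        similarities to queued features are split into positives (same label)
        and negatives (other labels), the strongest of each are kept, and
        None is returned if any sample has no positive or no negative."""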
label_queue = self.label_queue.clone().detach() # K
feature_queue = self.feature_queue.clone().detach() # K * hidden_size
# 1. expand label_queue and feature_queue to batch_size * K
batch_size = label_q.shape[0]
tmp_label_queue = label_queue.repeat([batch_size, 1])
tmp_feature_queue = feature_queue.unsqueeze(0)
tmp_feature_queue = tmp_feature_queue.repeat([batch_size, 1, 1]) # batch_size * K * hidden_size
        # 2. calculate cosine similarity between each query and every queued feature
cos_sim = torch.einsum('nc,nkc->nk', [liner_q, tmp_feature_queue])
        # 3. get indices of positive and negative samples
tmp_label = label_q.unsqueeze(1)
tmp_label = tmp_label.repeat([1, self.K])
pos_mask_index = torch.eq(tmp_label_queue, tmp_label)
neg_mask_index = ~ pos_mask_index
        # 4. keep only negative similarities (positions of positives are set to -inf)
feature_value = cos_sim.masked_select(neg_mask_index)
neg_sample = torch.full_like(cos_sim, -np.inf).cuda()
neg_sample = neg_sample.masked_scatter(neg_mask_index, feature_value)
        # 5. top-k selection
pos_mask_index = pos_mask_index.int()
pos_number = pos_mask_index.sum(dim=-1)
pos_min = pos_number.min()
if pos_min == 0:
return None
pos_sample, _ = cos_sim.topk(pos_min, dim=-1)
        pos_sample_top_k = pos_sample[:, 0:self.top_k]  # self.top_k = 25
pos_sample = pos_sample_top_k
pos_sample = pos_sample.contiguous().view([-1, 1])
neg_mask_index = neg_mask_index.int()
neg_number = neg_mask_index.sum(dim=-1)
neg_min = neg_number.min()
if neg_min == 0:
return None
neg_sample, _ = neg_sample.topk(neg_min, dim=-1)
neg_topk = min(pos_min, self.top_k)
neg_sample = neg_sample.repeat([1, neg_topk])
neg_sample = neg_sample.view([-1, neg_min])
logits_con = torch.cat([pos_sample, neg_sample], dim=-1)
logits_con /= self.T
return logits_con
def init_weights(self):
for param_q, param_k in zip(self.contrastive_liner_q.parameters(), self.contrastive_liner_k.parameters()):
param_k.data = param_q.data
def update_encoder_k(self):
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
for param_q, param_k in zip(self.contrastive_liner_q.parameters(), self.contrastive_liner_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
def reshape_dict(self, batch):
for k, v in batch.items():
shape = v.shape
batch[k] = v.view([-1, shape[-1]])
return batch
def l2norm(self, x: torch.Tensor):
norm = torch.pow(x, 2).sum(dim=-1, keepdim=True).sqrt()
x = torch.div(x, norm)
return x
def forward_no_multi_v2(self,
query,
positive_sample=None,
negative_sample=None,
):
labels = query["labels"]
labels = labels.view(-1)
with torch.no_grad():
self.update_encoder_k()
update_sample = self.reshape_dict(positive_sample)
bert_output_p = self.encoder_k(**update_sample)
update_keys = bert_output_p[1]
update_keys = self.contrastive_liner_k(update_keys)
update_keys = self.l2norm(update_keys)
tmp_labels = labels.unsqueeze(-1)
tmp_labels = tmp_labels.repeat([1, self.update_num])
tmp_labels = tmp_labels.view(-1)
self._dequeue_and_enqueue(update_keys, tmp_labels)
query.pop('labels')
bert_output_q = self.encoder_q(**query)
q = bert_output_q[1]
liner_q = self.contrastive_liner_q(q)
liner_q = self.l2norm(liner_q)
logits_cls = self.classifier_liner(q)
if self.number_labels == 1:
loss_fct = MSELoss()
loss_cls = loss_fct(logits_cls.view(-1), labels)
else:
loss_fct = CrossEntropyLoss()
loss_cls = loss_fct(logits_cls.view(-1, self.number_labels - 1), labels)
logits_con = self.select_pos_neg_sample(liner_q, labels)
if logits_con is not None:
labels_con = torch.zeros(logits_con.shape[0], dtype=torch.long).cuda()
loss_fct = CrossEntropyLoss()
loss_con = loss_fct(logits_con, labels_con)
loss = loss_con * self.contrastive_rate_in_training + \
loss_cls * (1 - self.contrastive_rate_in_training)
else:
loss = loss_cls
return SequenceClassifierOutput(
loss=loss,
)
def forward(self,
query, # batch_size * max_length
mode,
positive_sample=None, # batch_size * max_length
negative_sample=None, # batch_size * sample_num * max_length
):
if mode == 'train':
return self.forward_no_multi_v2(query=query, positive_sample=positive_sample,
negative_sample=negative_sample)
elif mode == 'validation':
labels = query['labels']
query.pop('labels')
seq_embed = self.encoder_q(**query)[1]
logits_cls = self.classifier_liner(seq_embed)
probs = torch.softmax(logits_cls, dim=1)
return torch.argmax(probs, dim=1).tolist(), labels.cpu().numpy().tolist()
elif mode == 'test':
query.pop('labels')
seq_embed = self.encoder_q(**query)[1]
logits_cls = self.classifier_liner(seq_embed)
probs = torch.softmax(logits_cls, dim=1)
return probs, seq_embed
else:
raise ValueError("undefined mode")
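

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training pipeline). The
# `args` fields below are assumptions inferred from how `args` is used above;
# the real scripts pass a much richer configuration object.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import BertConfig

    # Hypothetical, tiny configuration so the check runs quickly on CPU.
    args = SimpleNamespace(num_labels=5, activation="relu")
    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=128)
    model = BERT(config, args)
    model.eval()

    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    dummy_mask = torch.ones_like(dummy_ids)
    feats, logits = model(input_ids=dummy_ids, attention_mask=dummy_mask, mode="eval")
    print(feats.shape, logits.shape)  # expected: (2, 64) and (2, 5)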