THU-IAR's picture
Upload 198 files
2d06dcc verified
import torch
from torch import nn
import numpy as np
def l2_norm(input, axis=1):
    """Scale ``input`` to unit L2 norm along ``axis``.

    Args:
        input: Tensor to normalise.
        axis: Dimension along which the L2 norm is taken (default ``1``).

    Returns:
        ``input`` divided element-wise by its L2 norm along ``axis``.

    Note:
        No epsilon is applied, so a slice whose norm is zero is divided by
        zero.
    """
    # keepdim=True keeps a size-1 dimension at ``axis`` so the norm
    # broadcasts against ``input`` in the division below.
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude
class L2_normalization(nn.Module):
    """``nn.Module`` wrapper that applies ``l2_norm`` to its input."""

    def forward(self, input):
        """Return ``l2_norm(input)`` using that function's default axis."""
        normalized = l2_norm(input)
        return normalized
def freeze_bert_parameters(model):
    """Freeze the parameters of ``model.bert`` except a fixed subset.

    Every parameter yielded by ``model.bert.named_parameters()`` gets
    ``requires_grad = False`` unless its name contains the substring
    ``"encoder.layer.11"`` or ``"pooler"``, in which case it gets
    ``requires_grad = True``.

    Args:
        model: Object exposing a ``bert`` attribute with a
            ``named_parameters()`` method.

    Returns:
        The same ``model`` object, modified in place.
    """
    trainable_markers = ("encoder.layer.11", "pooler")
    for param_name, weight in model.bert.named_parameters():
        # Trainable only when the name mentions one of the markers.
        weight.requires_grad = any(
            marker in param_name for marker in trainable_markers
        )
    return model
def freeze_bert_parameters_KCL(model):
    """Freeze ``model.encoder_q`` and ``model.encoder_k`` except a fixed subset.

    For each of the two encoders (``encoder_q`` first, then ``encoder_k``),
    every parameter from ``named_parameters()`` gets ``requires_grad = False``
    unless its name contains ``"encoder.layer.11"`` or ``"pooler"``, in which
    case it gets ``requires_grad = True``.

    Args:
        model: Object exposing ``encoder_q`` and ``encoder_k`` attributes,
            each with a ``named_parameters()`` method.

    Returns:
        The same ``model`` object, modified in place.
    """
    trainable_markers = ("encoder.layer.11", "pooler")
    for encoder_attr in ("encoder_q", "encoder_k"):
        # Look the encoder up lazily so encoder_q is fully processed before
        # encoder_k is touched.
        encoder = getattr(model, encoder_attr)
        for param_name, weight in encoder.named_parameters():
            weight.requires_grad = any(
                marker in param_name for marker in trainable_markers
            )
    return model
class ConvexSampler(nn.Module):
    """Append synthetic samples built from pairs of differently-labelled rows.

    Each synthetic sample is ``s * z[i] + (1 - s) * z[j]`` where ``i`` and
    ``j`` are two distinct batch positions whose labels differ and ``s`` is
    drawn with ``np.random.uniform(0, 1)``.  Every synthetic sample is
    labelled ``unseen_label_id``.
    """

    def __init__(self, args):
        """Copy the sampler settings from ``args``.

        Args:
            args: Object providing the attributes ``multiple_convex``,
                ``multiple_convex_eval``, ``unseen_label_id``, ``device``,
                ``train_batch_size`` and ``feat_dim``.
        """
        super(ConvexSampler, self).__init__()
        self.multiple_convex = args.multiple_convex
        self.multiple_convex_eval = args.multiple_convex_eval
        self.unseen_label_id = args.unseen_label_id
        self.device = args.device
        self.batch_size = args.train_batch_size
        self.oos_num = args.train_batch_size
        self.feat_dim = args.feat_dim

    def forward(self, z, label_ids, mode=None):
        """Return ``z`` and ``label_ids`` extended with synthetic samples.

        Args:
            z: Feature tensor indexed by sample along dimension 0.
            label_ids: Label tensor with one entry per sample in ``z``.
            mode: ``'train'`` appends ``batch_size * multiple_convex``
                samples, ``'eval'`` appends
                ``batch_size * multiple_convex_eval`` samples; any other
                value leaves the inputs untouched.

        Returns:
            Tuple ``(z, label_ids)``.  The inputs are returned unchanged
            when ``mode`` is not recognised, when the batch holds two or
            fewer samples, when no samples are requested, or when every
            label in the batch is identical (no valid pair exists).
        """
        if mode == 'train':
            num_samples = self.batch_size * self.multiple_convex
        elif mode == 'eval':
            num_samples = self.batch_size * self.multiple_convex_eval
        else:
            return z, label_ids
        return self._append_convex_samples(z, label_ids, num_samples)

    def _append_convex_samples(self, z, label_ids, num_samples):
        """Draw ``num_samples`` convex combinations and concatenate them."""
        batch = label_ids.size(0)
        if batch <= 2 or num_samples <= 0:
            return z, label_ids
        # The rejection loop below only accepts pairs with different labels;
        # with a single distinct label it would never terminate.
        if torch.unique(label_ids).numel() < 2:
            return z, label_ids

        # Read the labels once instead of comparing tensor elements (and
        # synchronising with the device) on every loop iteration.
        labels = label_ids.tolist()
        convex_list = []
        while len(convex_list) < num_samples:
            first, second = np.random.choice(batch, 2, replace=False)
            first, second = int(first), int(second)
            if labels[first] != labels[second]:
                weight = float(np.random.uniform(0, 1, 1)[0])
                convex_list.append(weight * z[first] + (1 - weight) * z[second])

        convex_samples = torch.cat(convex_list, dim=0).view(num_samples, -1)
        # Build the synthetic labels directly with the dtype/device of the
        # real labels so the concatenation cannot mismatch.
        unseen_labels = torch.full(
            (num_samples,),
            self.unseen_label_id,
            dtype=label_ids.dtype,
            device=label_ids.device,
        )
        z = torch.cat((z, convex_samples), dim=0)
        label_ids = torch.cat((label_ids, unseen_labels), dim=0)
        return z, label_ids
def pair_cosine_similarity(x, x_adv, eps=1e-8):
    """Compute pairwise cosine similarities within and across two batches.

    Row norms are taken along dimension 1, and each product of two norms is
    clamped from below at ``eps`` before it is used as a divisor.

    Args:
        x: 2-D tensor whose rows are the first set of vectors.
        x_adv: 2-D tensor whose rows are the second set of vectors.
        eps: Lower bound applied to every norm product.

    Returns:
        A 3-tuple of similarity matrices: ``x`` against ``x``, ``x_adv``
        against ``x_adv``, and ``x`` against ``x_adv``.
    """

    def _cosine(left, left_norm, right, right_norm):
        # Outer product of the row norms, floored at eps.
        scale = (left_norm * right_norm.t()).clamp(min=eps)
        return left.matmul(right.t()) / scale

    norm_x = x.norm(p=2, dim=1, keepdim=True)
    norm_adv = x_adv.norm(p=2, dim=1, keepdim=True)

    sim_within_x = _cosine(x, norm_x, x, norm_x)
    sim_within_adv = _cosine(x_adv, norm_adv, x_adv, norm_adv)
    sim_cross = _cosine(x, norm_x, x_adv, norm_adv)
    return sim_within_x, sim_within_adv, sim_cross
def nt_xent(x, x_adv, mask, cuda=True, t=0.1):
    """Compute a temperature-scaled contrastive loss over two batches.

    The three similarity matrices from ``pair_cosine_similarity`` are
    exponentiated after division by ``t``.  Entries selected by ``mask`` are
    normalised, logged, summed per row, divided by the per-row sum of
    ``mask``, and the negated mean is returned.

    Args:
        x: First batch of row vectors.
        x_adv: Second batch of row vectors, same number of rows as ``x``.
        mask: Square tensor with one row/column per sample; non-zero entries
            select the pairs that contribute to the loss.
            NOTE(review): presumably an integer 0/1 matrix whose diagonal is
            1 -- confirm against the caller.
        cuda: Retained for backward compatibility and no longer consulted;
            the helper identity matrix is created on ``mask.device``.
        t: Temperature dividing the similarities before ``exp``.

    Returns:
        Scalar tensor holding the negated mean of the per-row loss terms.
    """
    sim, sim_adv, sim_cross = pair_cosine_similarity(x, x_adv)
    sim = torch.exp(sim / t)
    sim_adv = torch.exp(sim_adv / t)
    sim_cross = torch.exp(sim_cross / t)

    mask_count = mask.sum(1)
    # 1 where mask is 0: those entries become log(0 + 1) = 0 when their
    # numerator is zero, so they add nothing to the row sums.
    mask_reverse = (~(mask.bool())).long()

    # Identity built on the mask's own device, so the function works on CPU
    # and on any CUDA device without a hard-coded ``.cuda()`` call.
    eye = torch.eye(sim.size(0), device=mask.device).long()
    mask_minus_eye = mask - eye

    # exp(1 / t) is the value exp(sim / t) takes for a similarity of exactly
    # 1; it is subtracted from each denominator.
    self_term = torch.exp(torch.tensor(1 / t))

    dis = (sim * mask_minus_eye + sim_cross * mask) / (
        sim.sum(1) + sim_cross.sum(1) - self_term
    ) + mask_reverse
    dis_adv = (sim_adv * mask_minus_eye + sim_cross.T * mask) / (
        sim_adv.sum(1) + sim_cross.sum(0) - self_term
    ) + mask_reverse

    loss = (torch.log(dis).sum(1) + torch.log(dis_adv).sum(1)) / mask_count
    return -loss.mean()