import torch
from torch import nn
import numpy as np


def l2_norm(input, axis=1):
    """L2-normalize `input` along dimension `axis`."""
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class L2_normalization(nn.Module):
    """Module wrapper around `l2_norm`, usable inside nn.Sequential."""

    def forward(self, input):
        return l2_norm(input)


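# Minimal usage sketch (illustrative; the `_demo_*` helpers in this file are
# not part of the original pipeline).
def _demo_l2_normalization():
    feats = torch.randn(4, 768)  # e.g. a batch of 768-dim BERT features
    normed = L2_normalization()(feats)
    # Every row of the output has unit L2 norm.
    assert torch.allclose(normed.norm(p=2, dim=1), torch.ones(4), atol=1e-5)

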
def freeze_bert_parameters(model):
    """Freeze all BERT parameters except the last encoder layer and the pooler."""
    for name, param in model.bert.named_parameters():
        param.requires_grad = False
        if "encoder.layer.11" in name or "pooler" in name:
            param.requires_grad = True
    return model


def freeze_bert_parameters_KCL(model):
    """Apply the same freezing policy to both encoders of a MoCo-style KCL model."""
    for encoder in (model.encoder_q, model.encoder_k):
        for name, param in encoder.named_parameters():
            param.requires_grad = False
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True
    return model


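# Usage sketch (illustrative): after freezing, only layer-11 and pooler
# parameters should still require gradients. `model` is assumed to expose a
# `.bert` submodule, as `freeze_bert_parameters` itself requires.
def _demo_freeze(model):
    model = freeze_bert_parameters(model)
    # Expect only names containing "encoder.layer.11" or "pooler".
    return [n for n, p in model.bert.named_parameters() if p.requires_grad]

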
class ConvexSampler(nn.Module):
    """Generate pseudo-OOS samples as convex combinations of features drawn
    from two different known classes."""

    def __init__(self, args):
        super(ConvexSampler, self).__init__()
        self.multiple_convex = args.multiple_convex
        self.multiple_convex_eval = args.multiple_convex_eval
        self.unseen_label_id = args.unseen_label_id
        self.device = args.device
        self.batch_size = args.train_batch_size
        self.oos_num = args.train_batch_size
        self.feat_dim = args.feat_dim

    def forward(self, z, label_ids, mode=None):
        if mode == 'train':
            num_target = self.batch_size * self.multiple_convex
        elif mode == 'eval':
            num_target = self.batch_size * self.multiple_convex_eval
        else:
            return z, label_ids
        if label_ids.size(0) > 2:
            convex_list = []
            while len(convex_list) < num_target:
                # Pick two distinct samples; interpolate only across classes.
                cdt = np.random.choice(label_ids.size(0), 2, replace=False)
                if label_ids[cdt[0]] != label_ids[cdt[1]]:
                    s = np.random.uniform(0, 1, 1)
                    convex_list.append(s[0] * z[cdt[0]] + (1 - s[0]) * z[cdt[1]])
            convex_samples = torch.cat(convex_list, dim=0).view(num_target, -1)
            z = torch.cat((z, convex_samples), dim=0)
            label_ids = torch.cat(
                (label_ids,
                 torch.tensor([self.unseen_label_id] * num_target).to(self.device)),
                dim=0)
        return z, label_ids


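# Usage sketch (illustrative; the `args` fields are assumptions taken from the
# constructor above, not the original training config).
def _demo_convex_sampler(args, features, label_ids):
    sampler = ConvexSampler(args)
    # Appends args.train_batch_size * args.multiple_convex interpolated
    # features, all labeled with args.unseen_label_id.
    return sampler(features, label_ids, mode='train')

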
def pair_cosine_similarity(x, x_adv, eps=1e-8):
    """Return cosine-similarity matrices within x, within x_adv, and across the two views."""
    n = x.norm(p=2, dim=1, keepdim=True)
    n_adv = x_adv.norm(p=2, dim=1, keepdim=True)
    sim = (x @ x.t()) / (n * n.t()).clamp(min=eps)
    sim_adv = (x_adv @ x_adv.t()) / (n_adv * n_adv.t()).clamp(min=eps)
    sim_cross = (x @ x_adv.t()) / (n * n_adv.t()).clamp(min=eps)
    return sim, sim_adv, sim_cross


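# Sanity-check sketch (illustrative): the diagonal of a within-view similarity
# matrix is 1, since each row is compared with itself.
def _demo_pair_cosine(x, x_adv):
    sim, _, _ = pair_cosine_similarity(x, x_adv)
    assert torch.allclose(sim.diagonal(), torch.ones(x.size(0)), atol=1e-5)

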
def nt_xent(x, x_adv, mask, cuda=True, t=0.1):
    """Supervised NT-Xent loss over a clean view x and an augmented view x_adv.

    `mask[i][j] == 1` iff samples i and j share a label (diagonal included).
    """
    x, x_adv, x_c = pair_cosine_similarity(x, x_adv)
    x = torch.exp(x / t)
    x_adv = torch.exp(x_adv / t)
    x_c = torch.exp(x_c / t)
    mask_count = mask.sum(1)
    mask_reverse = (~(mask.bool())).long()
    eye = torch.eye(x.size(0)).long()
    if cuda:
        eye = eye.cuda()
    # exp(1 / t) removes each sample's self-similarity from the denominator;
    # adding mask_reverse sets negative-pair entries to 1 so their log is 0.
    dis = (x * (mask - eye) + x_c * mask) / (x.sum(1) + x_c.sum(1) - torch.exp(torch.tensor(1 / t))) + mask_reverse
    dis_adv = (x_adv * (mask - eye) + x_c.T * mask) / (x_adv.sum(1) + x_c.sum(0) - torch.exp(torch.tensor(1 / t))) + mask_reverse
    loss = (torch.log(dis).sum(1) + torch.log(dis_adv).sum(1)) / mask_count
    return -loss.mean()


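# Usage sketch (illustrative): build the label-match mask for a batch and
# compute the contrastive loss between two views of the same examples.
def _demo_nt_xent(feats, feats_adv, labels):
    mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).long()
    return nt_xent(feats, feats_adv, mask, cuda=feats.is_cuda)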