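# NOTE: the following lines are residue from the hosting web page, kept as comments so the module parses.
# THU-IAR's picture
# Upload 198 files
# 2d06dcc verified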
import torch
import torch.nn.functional as F
import numpy as np
import logging
import os
import time
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from tqdm import trange, tqdm
from losses import loss_map
from utils.functions import save_model, restore_model
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
from transformers import BertTokenizer
from torch import nn
from utils.metrics import clustering_score
from utils.functions import view_generator
from losses import loss_map
from .pretrain import PretrainUSNIDManager
from utils.functions import set_seed
class USNIDManager:
    """Discovery-stage manager for USNID.

    Reuses (or restores from disk) the model produced by
    ``PretrainUSNIDManager``. Training alternates K-means pseudo-labelling of
    the training set with contrastive + cross-entropy updates (see
    :meth:`train`). Evaluation runs K-means over test-set features (see
    :meth:`test`).
    """

    # Parameter names dropped from the pretrained state dict before loading it,
    # so these parameters of ``self.model`` keep whatever values they already had.
    _CLASSIFIER_PARAMS = frozenset({
        'mlp_head.bias', 'mlp_head.0.bias', 'classifier.weight',
        'classifier.bias', 'mlp_head.0.weight', 'mlp_head.weight',
    })

    def __init__(self, args, data, model, logger_name='Discovery'):
        """Build the discovery manager.

        Args:
            args: run configuration. Reads ``seed``, ``pretrained_bert_model``,
                ``pretrain``, ``train``, ``method_output_dir`` and
                ``model_output_dir`` here, plus the fields used by
                :meth:`set_model_optimizer`.
            data: dataset wrapper exposing ``dataloader``, ``n_known_cls`` and
                ``num_labels``.
            model: model factory exposing ``set_model``, ``set_optimizer`` and
                ``device``.
            logger_name: name passed to ``logging.getLogger``.
        """
        # The pretrain manager is constructed before seeding, as in the original flow.
        pretrain_manager = PretrainUSNIDManager(args, data, model)
        set_seed(args.seed)
        self.logger = logging.getLogger(logger_name)

        loader = data.dataloader
        self.train_dataloader = loader.train_outputs['loader']
        self.eval_dataloader = loader.eval_outputs['loader']
        self.test_dataloader = loader.test_outputs['loader']
        self.train_input_ids = loader.train_outputs['input_ids']
        self.train_input_mask = loader.train_outputs['input_mask']
        self.train_segment_ids = loader.train_outputs['segment_ids']
        self.train_outputs = loader.train_outputs
        self.train_labeled_outputs = loader.train_labeled_outputs
        self.train_labeled_dataloader = loader.train_labeled_outputs['loader']

        self.criterion = loss_map['CrossEntropyLoss']
        self.contrast_criterion = loss_map['SupConLoss']
        self.tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model, do_lower_case=True)
        self.generator = view_generator(self.tokenizer, args)
        self.n_known_cls = data.n_known_cls
        # Centroids from the most recent clustering() call; None until one runs,
        # so test() can tell whether a warm start is available.
        self.centroids = None

        # NOTE(review): the source had lost its indentation. `if args.train` is
        # nested under `else` here because, at top level, the pretrain branch
        # would load the pretrained weights twice. Confirm against upstream.
        if args.pretrain:
            self.pretrained_model = pretrain_manager.model
            self.set_model_optimizer(args, data, model, pretrain_manager)
            self.load_pretrained_model(args, self.pretrained_model)
        else:
            self.pretrained_model = restore_model(
                pretrain_manager.model, os.path.join(args.method_output_dir, 'pretrain'))
            self.set_model_optimizer(args, data, model, pretrain_manager)
            if args.train:
                self.load_pretrained_model(args, self.pretrained_model)
            else:
                self.model = restore_model(self.model, args.model_output_dir)

    def set_model_optimizer(self, args, data, model, pretrain_manager):
        """Create ``self.model`` and its two optimizer/scheduler pairs.

        Sets ``self.num_labels`` (and mirrors it into ``args.num_labels``):
        the pretrain manager's value when ``args.cluster_num_factor > 1``,
        otherwise ``data.num_labels``.

        ``self.optimizer``/``self.scheduler`` are built from the size of the
        full training set. ``self.l_optimizer``/``self.l_scheduler`` are built
        from the size of the labeled subset.
        """
        if args.cluster_num_factor > 1:
            args.num_labels = self.num_labels = pretrain_manager.num_labels
        else:
            args.num_labels = self.num_labels = data.num_labels

        self.model = model.set_model(args, data, 'bert', args.freeze_train_bert_parameters)
        self.optimizer, self.scheduler = model.set_optimizer(
            self.model, len(data.dataloader.train_examples), args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.l_optimizer, self.l_scheduler = model.set_optimizer(
            self.model, len(data.dataloader.train_labeled_examples), args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.device = model.device

    def clustering(self, args, init='k-means++'):
        """Run K-means on training-set features and cache the centroids.

        Args:
            args: needs ``seed`` (passed as the K-means ``random_state``).
            init: ``'k-means++'`` for a fresh initialisation, or ``'centers'``
                to warm-start from ``self.centroids`` left by a previous call.

        Returns:
            ``(outputs, km_centroids, y_true, assign_labels, pseudo_labels)``,
            where ``outputs`` is the dict from :meth:`get_outputs` and
            ``pseudo_labels`` is ``assign_labels`` cast to int64.

        Raises:
            ValueError: unknown ``init``, or ``'centers'`` requested before any
                centroids exist.
        """
        outputs = self.get_outputs(args, mode='train', model=self.model)
        feats = outputs['feats']
        y_true = outputs['y_true']

        if init == 'k-means++':
            self.logger.info('Initializing centroids with K-means++...')
            start = time.time()
            # `n_jobs` was removed from KMeans in scikit-learn 1.0; passing it raises TypeError.
            km = KMeans(n_clusters=self.num_labels, random_state=args.seed, init='k-means++').fit(feats)
            self.logger.info('K-means++ used %s s', round(time.time() - start, 2))
        elif init == 'centers':
            if self.centroids is None:
                raise ValueError("init='centers' requires centroids from a previous clustering() call")
            start = time.time()
            km = KMeans(n_clusters=self.num_labels, random_state=args.seed,
                        init=self.centroids.cpu().numpy()).fit(feats)
            self.logger.info('K-means used %s s', round(time.time() - start, 2))
        else:
            raise ValueError("Unknown init mechanism: %r" % (init,))

        km_centroids, assign_labels = km.cluster_centers_, km.labels_
        self.centroids = torch.tensor(km_centroids).to(self.device)
        # np.long is absent in NumPy 1.24-1.26; int64 is explicit and portable.
        pseudo_labels = assign_labels.astype(np.int64)
        return outputs, km_centroids, y_true, assign_labels, pseudo_labels

    def train(self, args, data):
        """Alternate labeled contrastive learning, clustering and pseudo-label training.

        Each epoch does the following:

        1. One supervised-contrastive pass over the labeled loader.
        2. K-means over all training features: k-means++ at epoch 0, then
           warm-started from the previous centroids.
        3. From epoch 1 on, stop early when the fraction of changed pseudo
           labels drops below ``args.tol``.
        4. One cross-entropy + contrastive pass over two augmented views of
           every training example, labeled with the K-means assignments.

        Saves the model to ``args.model_output_dir`` when ``args.save_model``.
        """
        self.centroids = None
        last_preds = None
        tr_loss = 0.0

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self._train_labeled_contrastive(args)

            init_mechanism = 'k-means++' if epoch == 0 else 'centers'
            _, _, _, _, pseudo_labels = self.clustering(args, init=init_mechanism)

            if last_preds is not None:
                # Fraction of examples whose cluster assignment changed since last epoch.
                delta_label = np.mean(pseudo_labels != last_preds)
                self.logger.info("***** Epoch: %s *****", str(epoch))
                self.logger.info('Training Loss: %f', np.round(tr_loss, 5))
                self.logger.info('Delta Label: %f', delta_label)
                if delta_label < args.tol:
                    self.logger.info('delta_label %s < %f', delta_label, args.tol)
                    self.logger.info('Reached tolerance threshold. Stop training.')
                    break
            last_preds = np.copy(pseudo_labels)

            tr_loss = self._train_pseudo_labeled(args, pseudo_labels)

        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def _contrastive_loss(self, args, feats_a, feats_b, label_ids):
        """Supervised contrastive loss over two L2-normalised views stacked along a new dim 1."""
        views = torch.stack((F.normalize(feats_a), F.normalize(feats_b)), dim=1)
        return self.contrast_criterion(views, labels=label_ids,
                                       temperature=args.train_temperature, device=self.device)

    def _train_labeled_contrastive(self, args):
        """Run one supervised-contrastive epoch over the labeled loader (labeled optimizer)."""
        self.model.train()
        for batch in tqdm(self.train_labeled_dataloader, desc="Training(All)"):
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            with torch.set_grad_enabled(True):
                # The same batch is forwarded twice. The two outputs presumably
                # differ only through stochastic layers (e.g. dropout) in train mode.
                mlp_a, _ = self.model(input_ids, segment_ids, input_mask)
                mlp_b, _ = self.model(input_ids, segment_ids, input_mask)
                loss = self._contrastive_loss(args, mlp_a, mlp_b, label_ids)
                self.l_optimizer.zero_grad()
                loss.backward()
                self.l_optimizer.step()
                self.l_scheduler.step()

    def _train_pseudo_labeled(self, args, pseudo_labels):
        """Run one CE + contrastive epoch on augmented, pseudo-labeled views; return the mean loss."""
        pseudo_train_dataloader = self.get_augment_dataloader(args, self.train_outputs, pseudo_labels)
        tr_loss, nb_tr_steps = 0.0, 0
        self.model.train()
        for batch in tqdm(pseudo_train_dataloader, desc="Training(All)"):
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            with torch.set_grad_enabled(True):
                input_ids_a, input_ids_b = self.batch_chunk(input_ids)
                input_mask_a, input_mask_b = self.batch_chunk(input_mask)
                segment_ids_a, segment_ids_b = self.batch_chunk(segment_ids)
                # Both views carry the same label (see get_augment_dataloader); keep one column.
                label_ids = label_ids[:, 0]

                mlp_a, logits_a = self.model(input_ids_a, segment_ids_a, input_mask_a)
                mlp_b, logits_b = self.model(input_ids_b, segment_ids_b, input_mask_b)
                loss_ce = 0.5 * (self.criterion(logits_a, label_ids) + self.criterion(logits_b, label_ids))
                loss = self._contrastive_loss(args, mlp_a, mlp_b, label_ids) + loss_ce

                self.optimizer.zero_grad()
                loss.backward()
                # -1.0 is the "clipping disabled" sentinel.
                if args.grad_clip != -1.0:
                    nn.utils.clip_grad_value_(
                        [p for p in self.model.parameters() if p.requires_grad], args.grad_clip)
                tr_loss += loss.item()
                nb_tr_steps += 1
                self.optimizer.step()
                self.scheduler.step()
        # An empty loader would otherwise raise ZeroDivisionError.
        return tr_loss / nb_tr_steps if nb_tr_steps else 0.0

    def test(self, args, data):
        """Cluster test-set features with K-means and score against ground truth.

        Warm-starts K-means from ``self.centroids`` when training has produced
        them. Otherwise (e.g. a model restored without training) it falls back
        to k-means++.

        Returns:
            The dict from ``clustering_score``, plus ``'estimate_k'`` (when
            ``args.cluster_num_factor > 1``), ``'y_true'`` and ``'y_pred'``.
        """
        outputs = self.get_outputs(args, mode='test', model=self.model)
        feats = outputs['feats']
        y_true = outputs['y_true']

        centroids = getattr(self, 'centroids', None)
        init = 'k-means++' if centroids is None else centroids.cpu().numpy()
        km = KMeans(n_clusters=self.num_labels, random_state=args.seed, init=init).fit(feats)
        y_pred = km.labels_

        test_results = clustering_score(y_true, y_pred)
        # Added after clustering_score so it is not overwritten, and before logging so it is reported.
        if args.cluster_num_factor > 1:
            test_results['estimate_k'] = args.num_labels

        cm = confusion_matrix(y_true, y_pred)
        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")
        for key in sorted(test_results.keys()):
            self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results

    def get_outputs(self, args, mode, model):
        """Run ``model`` in eval mode over one split and collect numpy outputs.

        Args:
            args: needs ``feat_dim`` (feature width used for the empty-input case).
            mode: ``'train'``, ``'eval'`` or ``'test'``.
            model: called as ``model(input_ids, segment_ids, input_mask, feature_ext=True)``
                and expected to return ``(features, logits)``.

        Returns:
            Dict with ``'y_true'`` (labels), ``'y_pred'`` (argmax of softmaxed logits),
            ``'logits'`` and ``'feats'``, all as numpy arrays.

        Raises:
            ValueError: unknown ``mode``.
        """
        if mode == 'eval':
            dataloader = self.eval_dataloader
        elif mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'train':
            dataloader = self.train_dataloader
        else:
            raise ValueError("Unknown mode: %r" % (mode,))

        model.eval()
        # Seed each list with an empty tensor so torch.cat still returns the
        # original (0, ...) shapes and dtypes when the loader yields nothing.
        labels = [torch.empty(0, dtype=torch.long, device=self.device)]
        features = [torch.empty((0, args.feat_dim), device=self.device)]
        logits_list = [torch.empty((0, self.num_labels), device=self.device)]

        for batch in tqdm(dataloader, desc="Iteration"):
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            with torch.set_grad_enabled(False):
                pooled_output, logits = model(input_ids, segment_ids, input_mask, feature_ext=True)
            labels.append(label_ids)
            features.append(pooled_output)
            logits_list.append(logits)

        # A single concatenation instead of re-copying the accumulator every batch.
        total_labels = torch.cat(labels)
        total_features = torch.cat(features)
        total_logits = torch.cat(logits_list)

        total_probs = F.softmax(total_logits.detach(), dim=1)
        _, total_preds = total_probs.max(dim=1)

        return {
            'y_true': total_labels.cpu().numpy(),
            'y_pred': total_preds.cpu().numpy(),
            'logits': total_logits.cpu().numpy(),
            'feats': total_features.cpu().numpy(),
        }

    def load_pretrained_model(self, args, pretrained_model):
        """Copy ``pretrained_model``'s weights into ``self.model``, skipping classifier/MLP-head keys.

        ``strict=False`` tolerates the missing head keys (and any other key mismatch).
        """
        pretrained_dict = {
            k: v for k, v in pretrained_model.state_dict().items()
            if k not in self._CLASSIFIER_PARAMS
        }
        self.model.load_state_dict(pretrained_dict, strict=False)

    def batch_chunk(self, x):
        """Split ``x`` in two along dim 1 and drop that dim; ``(B, 2, ...)`` -> two ``(B, ...)``."""
        x1, x2 = torch.chunk(input=x, chunks=2, dim=1)
        return x1.squeeze(1), x2.squeeze(1)

    def get_augment_dataloader(self, args, train_outputs, pseudo_labels=None):
        """Build a shuffled loader of two token-erased views per training example.

        Each tensor in the dataset stacks view a and view b along dim 1, so
        :meth:`batch_chunk` can split them apart again. The label is duplicated
        for both views.

        Args:
            args: needs ``train_batch_size``.
            train_outputs: dict with ``'input_ids'``, ``'input_mask'``,
                ``'segment_ids'`` and, when ``pseudo_labels`` is None, ``'label_ids'``.
            pseudo_labels: per-example labels to train on. Defaults to
                ``train_outputs['label_ids']``.
        """
        input_ids = train_outputs['input_ids']
        input_mask = train_outputs['input_mask']
        segment_ids = train_outputs['segment_ids']
        if pseudo_labels is None:
            pseudo_labels = train_outputs['label_ids']

        input_ids_a, input_mask_a = self.generator.random_token_erase(input_ids, input_mask)
        input_ids_b, input_mask_b = self.generator.random_token_erase(input_ids, input_mask)

        train_input_ids = torch.stack((input_ids_a, input_ids_b), dim=1)
        # Each view is paired with its own mask (the original reused view a's mask for view b).
        train_input_mask = torch.stack((input_mask_a, input_mask_b), dim=1)
        train_segment_ids = torch.stack((segment_ids, segment_ids), dim=1)
        # as_tensor avoids torch.tensor's copy-construct warning when labels are already a tensor.
        label_col = torch.as_tensor(pseudo_labels)
        train_label_ids = torch.stack((label_col, label_col), dim=1)

        train_data = TensorDataset(train_input_ids, train_input_mask, train_segment_ids, train_label_ids)
        sampler = RandomSampler(train_data)
        return DataLoader(train_data, sampler=sampler, batch_size=args.train_batch_size)