import logging
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tqdm import tqdm, trange
from transformers import BertTokenizer

from losses import loss_map
from utils.functions import restore_model, save_model, set_seed, view_generator
from utils.metrics import clustering_score

from .pretrain import PretrainUSNIDManager


class USNIDManager:
    """Training/evaluation manager for USNID-style new intent discovery.

    Alternates, per epoch, between (1) supervised contrastive training on the
    labeled subset, (2) K-Means clustering of train-set features to obtain
    pseudo labels, and (3) joint cross-entropy + supervised-contrastive
    training on two augmented views of the full train set.
    """

    def __init__(self, args, data, model, logger_name='Discovery'):
        """Build dataloaders, losses, tokenizer and (pre-trained) model.

        Args:
            args: experiment configuration namespace (seed, paths, flags, ...).
            data: dataset wrapper exposing `dataloader` and label counts.
            model: model factory exposing `set_model` / `set_optimizer`.
            logger_name: name of the logger to emit progress messages to.
        """
        pretrain_manager = PretrainUSNIDManager(args, data, model)

        set_seed(args.seed)
        self.logger = logging.getLogger(logger_name)

        loader = data.dataloader
        self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
            loader.train_outputs['loader'], loader.eval_outputs['loader'], loader.test_outputs['loader']
        self.train_input_ids, self.train_input_mask, self.train_segment_ids = \
            loader.train_outputs['input_ids'], loader.train_outputs['input_mask'], loader.train_outputs['segment_ids']
        self.train_outputs = loader.train_outputs
        self.train_labeled_outputs = loader.train_labeled_outputs
        self.train_labeled_dataloader = loader.train_labeled_outputs['loader']

        self.criterion = loss_map['CrossEntropyLoss']
        self.contrast_criterion = loss_map['SupConLoss']

        self.tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model, do_lower_case=True)
        # Produces augmented views (random token erasure) for contrastive learning.
        self.generator = view_generator(self.tokenizer, args)

        self.n_known_cls = data.n_known_cls

        if args.pretrain:
            self.pretrained_model = pretrain_manager.model
            self.set_model_optimizer(args, data, model, pretrain_manager)
            # Transfer pretrained backbone weights (classifier heads excluded).
            self.load_pretrained_model(args, self.pretrained_model)
        else:
            self.pretrained_model = restore_model(
                pretrain_manager.model, os.path.join(args.method_output_dir, 'pretrain'))
            self.set_model_optimizer(args, data, model, pretrain_manager)
            if args.train:
                self.load_pretrained_model(args, self.pretrained_model)
            else:
                # Evaluation-only run: restore a previously trained model.
                self.model = restore_model(self.model, args.model_output_dir)

    def set_model_optimizer(self, args, data, model, pretrain_manager):
        """Instantiate the model plus two optimizer/scheduler pairs.

        One optimizer schedule is sized for the full train set, the other for
        the labeled subset (used in the supervised-contrastive phase).
        When `cluster_num_factor > 1` the number of clusters was estimated
        during pretraining, so that estimate overrides the dataset's count.
        """
        if args.cluster_num_factor > 1:
            args.num_labels = self.num_labels = pretrain_manager.num_labels
        else:
            args.num_labels = self.num_labels = data.num_labels

        self.model = model.set_model(args, data, 'bert', args.freeze_train_bert_parameters)

        # Optimizer/scheduler for the pseudo-label phase (all train examples).
        self.optimizer, self.scheduler = model.set_optimizer(
            self.model, len(data.dataloader.train_examples), args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        # Optimizer/scheduler for the labeled-only contrastive phase.
        self.l_optimizer, self.l_scheduler = model.set_optimizer(
            self.model, len(data.dataloader.train_labeled_examples), args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)

        self.device = model.device

    def clustering(self, args, init='k-means++'):
        """Cluster train-set features with K-Means and derive pseudo labels.

        Args:
            args: configuration (uses `seed`).
            init: 'k-means++' for fresh initialization, or 'centers' to warm-
                start from the centroids of the previous epoch.

        Returns:
            Tuple of (feature-extraction outputs dict, centroid array,
            ground-truth labels, raw K-Means assignments, int64 pseudo labels).
        """
        outputs = self.get_outputs(args, mode='train', model=self.model)
        feats = outputs['feats']
        y_true = outputs['y_true']

        if init == 'k-means++':
            self.logger.info('Initializing centroids with K-means++...')
            start = time.time()
            # NOTE(review): `n_jobs` was removed from KMeans in scikit-learn 1.0;
            # it has been dropped here (parallelism is controlled globally now).
            km = KMeans(n_clusters=self.num_labels, random_state=args.seed,
                        init='k-means++').fit(feats)
            km_centroids, assign_labels = km.cluster_centers_, km.labels_
            end = time.time()
            self.logger.info('K-means++ used %s s', round(end - start, 2))
        elif init == 'centers':
            # Warm start from last epoch's centroids for faster convergence.
            start = time.time()
            km = KMeans(n_clusters=self.num_labels, random_state=args.seed,
                        init=self.centroids.cpu().numpy()).fit(feats)
            km_centroids, assign_labels = km.cluster_centers_, km.labels_
            end = time.time()
            self.logger.info('K-means used %s s', round(end - start, 2))

        self.centroids = torch.tensor(km_centroids).to(self.device)
        # np.long was removed in NumPy 1.24; np.int64 is the equivalent dtype.
        pseudo_labels = assign_labels.astype(np.int64)

        return outputs, km_centroids, y_true, assign_labels, pseudo_labels

    def train(self, args, data):
        """Run the alternating contrastive / pseudo-label training loop.

        Stops early once the fraction of pseudo labels that changed between
        epochs (`delta_label`) drops below `args.tol`.
        """
        self.centroids = None
        last_preds = None

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):

            # --- Phase 1: supervised contrastive learning on labeled data.
            # Two forward passes of the same batch give two dropout-perturbed
            # views, which SupConLoss pulls together per gold label.
            self.model.train()
            for batch in tqdm(self.train_labeled_dataloader, desc="Training(All)"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                with torch.set_grad_enabled(True):
                    aug_mlp_outputs_a, aug_logits_a = self.model(input_ids, segment_ids, input_mask)
                    aug_mlp_outputs_b, aug_logits_b = self.model(input_ids, segment_ids, input_mask)

                    norm_logits = F.normalize(aug_mlp_outputs_a)
                    norm_aug_logits = F.normalize(aug_mlp_outputs_b)

                    contrastive_feats = torch.cat(
                        (norm_logits.unsqueeze(1), norm_aug_logits.unsqueeze(1)), dim=1)
                    loss_contrast = self.contrast_criterion(
                        contrastive_feats, labels=label_ids,
                        temperature=args.train_temperature, device=self.device)

                    loss = loss_contrast

                    self.l_optimizer.zero_grad()
                    loss.backward()
                    self.l_optimizer.step()
                    self.l_scheduler.step()

            # --- Phase 2: (re-)cluster features to obtain pseudo labels.
            init_mechanism = 'k-means++' if epoch == 0 else 'centers'
            outputs, km_centroids, y_true, assign_labels, pseudo_labels = \
                self.clustering(args, init=init_mechanism)

            current_preds = pseudo_labels
            # At epoch 0, last_preds is None so every element counts as changed.
            delta_label = np.sum(current_preds != last_preds).astype(np.float32) / current_preds.shape[0]
            last_preds = np.copy(current_preds)

            if epoch > 0:
                self.logger.info("***** Epoch: %s *****", str(epoch))
                # tr_loss here is the training loss from the previous epoch.
                self.logger.info('Training Loss: %f', np.round(tr_loss, 5))
                self.logger.info('Delta Label: %f', delta_label)

                if delta_label < args.tol:
                    self.logger.info('delta_label %s < %f', delta_label, args.tol)
                    self.logger.info('Reached tolerance threshold. Stop training.')
                    break

            # --- Phase 3: joint CE + contrastive training on two augmented
            # views of the full train set, supervised by the pseudo labels.
            pseudo_train_dataloader = self.get_augment_dataloader(args, self.train_outputs, pseudo_labels)

            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            self.model.train()

            for batch in tqdm(pseudo_train_dataloader, desc="Training(All)"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                with torch.set_grad_enabled(True):
                    # Each tensor packs the two augmented views along dim 1.
                    input_ids_a, input_ids_b = self.batch_chunk(input_ids)
                    input_mask_a, input_mask_b = self.batch_chunk(input_mask)
                    segment_ids_a, segment_ids_b = self.batch_chunk(segment_ids)
                    # Both views share the same pseudo label; keep one copy.
                    label_ids = torch.chunk(input=label_ids, chunks=2, dim=1)[0][:, 0]

                    aug_mlp_outputs_a, aug_logits_a = self.model(input_ids_a, segment_ids_a, input_mask_a)
                    aug_mlp_outputs_b, aug_logits_b = self.model(input_ids_b, segment_ids_b, input_mask_b)

                    norm_logits = F.normalize(aug_mlp_outputs_a)
                    norm_aug_logits = F.normalize(aug_mlp_outputs_b)

                    loss_ce = 0.5 * (self.criterion(aug_logits_a, label_ids)
                                     + self.criterion(aug_logits_b, label_ids))

                    contrastive_feats = torch.cat(
                        (norm_logits.unsqueeze(1), norm_aug_logits.unsqueeze(1)), dim=1)
                    loss_contrast = self.contrast_criterion(
                        contrastive_feats, labels=label_ids,
                        temperature=args.train_temperature, device=self.device)

                    loss = loss_contrast + loss_ce

                    self.optimizer.zero_grad()
                    loss.backward()

                    if args.grad_clip != -1.0:
                        nn.utils.clip_grad_value_(
                            [param for param in self.model.parameters() if param.requires_grad],
                            args.grad_clip)

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1

                    self.optimizer.step()
                    self.scheduler.step()

            tr_loss = tr_loss / nb_tr_steps

        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def test(self, args, data):
        """Cluster test-set features with the learned centroids and score.

        Returns:
            dict of clustering metrics, plus 'y_true'/'y_pred' arrays and,
            when the cluster count was estimated, 'estimate_k'.
        """
        outputs = self.get_outputs(args, mode='test', model=self.model)
        feats = outputs['feats']
        y_true = outputs['y_true']

        km = KMeans(n_clusters=self.num_labels, random_state=args.seed,
                    init=self.centroids.cpu().numpy()).fit(feats)
        y_pred = km.labels_

        test_results = clustering_score(y_true, y_pred)
        # Record the estimated cluster count only after test_results exists
        # (the original assigned it before the dict was created -> NameError).
        if args.cluster_num_factor > 1:
            test_results['estimate_k'] = args.num_labels

        cm = confusion_matrix(y_true, y_pred)
        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")

        for key in sorted(test_results.keys()):
            self.logger.info("  %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred

        return test_results

    def get_outputs(self, args, mode, model):
        """Extract features, logits and labels for a whole data split.

        Args:
            mode: one of 'train', 'eval', 'test' — selects the dataloader.

        Returns:
            dict with numpy arrays: 'y_true', 'y_pred' (argmax of softmax),
            'logits', and 'feats' (pooled features of size args.feat_dim).
        """
        if mode == 'eval':
            dataloader = self.eval_dataloader
        elif mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'train':
            dataloader = self.train_dataloader

        model.eval()

        total_labels = torch.empty(0, dtype=torch.long).to(self.device)
        total_features = torch.empty((0, args.feat_dim)).to(self.device)
        total_logits = torch.empty((0, self.num_labels)).to(self.device)

        for batch in tqdm(dataloader, desc="Iteration"):
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.set_grad_enabled(False):
                pooled_output, logits = model(input_ids, segment_ids, input_mask, feature_ext=True)
                total_labels = torch.cat((total_labels, label_ids))
                total_features = torch.cat((total_features, pooled_output))
                total_logits = torch.cat((total_logits, logits))

        feats = total_features.cpu().numpy()
        y_true = total_labels.cpu().numpy()

        total_probs = F.softmax(total_logits.detach(), dim=1)
        total_maxprobs, total_preds = total_probs.max(dim=1)
        y_pred = total_preds.cpu().numpy()
        y_logits = total_logits.cpu().numpy()

        outputs = {
            'y_true': y_true,
            'y_pred': y_pred,
            'logits': y_logits,
            'feats': feats,
        }
        return outputs

    def load_pretrained_model(self, args, pretrained_model):
        """Copy pretrained weights into self.model, skipping classifier heads.

        The classifier/MLP-head parameters are excluded because the number of
        output classes may differ between pretraining and discovery.
        """
        pretrained_dict = pretrained_model.state_dict()
        classifier_params = ['mlp_head.bias', 'mlp_head.0.bias', 'classifier.weight',
                             'classifier.bias', 'mlp_head.0.weight', 'mlp_head.weight']
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in classifier_params}
        self.model.load_state_dict(pretrained_dict, strict=False)

    def batch_chunk(self, x):
        """Split a (batch, 2, ...) tensor into its two stacked views."""
        x1, x2 = torch.chunk(input=x, chunks=2, dim=1)
        x1, x2 = x1.squeeze(1), x2.squeeze(1)
        return x1, x2

    def get_augment_dataloader(self, args, train_outputs, pseudo_labels=None):
        """Build a dataloader that pairs two augmented views per example.

        Each example yields tensors of shape (2, seq_len): view A and view B
        produced by random token erasure. Labels default to the gold labels
        when no pseudo labels are given.
        """
        input_ids = train_outputs['input_ids']
        input_mask = train_outputs['input_mask']
        segment_ids = train_outputs['segment_ids']

        if pseudo_labels is None:
            pseudo_labels = train_outputs['label_ids']

        input_ids_a, input_mask_a = self.generator.random_token_erase(input_ids, input_mask)
        input_ids_b, input_mask_b = self.generator.random_token_erase(input_ids, input_mask)

        train_input_ids = torch.cat(([input_ids_a.unsqueeze(1), input_ids_b.unsqueeze(1)]), dim=1)
        # Fixed: the original concatenated input_mask_a with itself, giving
        # view B the attention mask of view A.
        train_input_mask = torch.cat(([input_mask_a.unsqueeze(1), input_mask_b.unsqueeze(1)]), dim=1)
        train_segment_ids = torch.cat(([segment_ids.unsqueeze(1), segment_ids.unsqueeze(1)]), dim=1)

        train_label_ids = torch.tensor(pseudo_labels).unsqueeze(1)
        train_label_ids = torch.cat(([train_label_ids, train_label_ids]), dim=1)

        train_data = TensorDataset(train_input_ids, train_input_mask, train_segment_ids, train_label_ids)
        sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=sampler, batch_size=args.train_batch_size)

        return train_dataloader