import copy
import logging

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from tqdm import trange, tqdm

from utils.functions import set_seed
from utils.metrics import clustering_score
from utils.functions import restore_model, save_model


def target_distribution(q):
    """Compute the DEC-style auxiliary target distribution from soft assignments.

    Each entry is squared and divided by its column (cluster) total, then every
    row is renormalised so it sums to one.

    Args:
        q: 2-D array with one row per sample and one column per cluster.
            Presumably non-negative soft assignments -- TODO confirm against
            the model's output.

    Returns:
        Array with the same shape as ``q`` whose rows sum to one.
    """
    weight = q ** 2 / q.sum(0)
    return (weight.T / weight.sum(1)).T


def _cat_or_empty(chunks, empty_shape, dtype):
    """Concatenate tensors along dim 0, or return an empty tensor if there are none."""
    if chunks:
        return torch.cat(chunks)
    return torch.empty(empty_shape, dtype=dtype)


class CDACPlusManager:
    """Train, refine and evaluate a CDAC+ clustering model.

    Training runs in two stages:
    1. Pairwise-similarity learning with annealed thresholds (:meth:`train`).
    2. KL-divergence cluster refinement against the auxiliary target
       distribution (:meth:`refine`).
    """

    def __init__(self, args, data, model, logger_name='Discovery'):
        self.logger = logging.getLogger(logger_name)
        set_seed(args.seed)

        loader = data.dataloader
        self.train_dataloader = loader.train_outputs['loader']
        self.eval_dataloader = loader.eval_outputs['loader']
        self.test_dataloader = loader.test_outputs['loader']
        self.train_labeled_dataloader = loader.train_labeled_outputs['loader']
        self.train_unlabeled_dataloader = loader.train_unlabeled_outputs['loader']

        self.model = model.set_model(args, data, 'bert')
        # One optimizer/scheduler pair per stage, each built with that stage's
        # epoch count (num_train_epochs for train, num_refine_epochs for refine).
        self.optimizer1, self.scheduler1 = model.set_optimizer(
            self.model, data.dataloader.num_train_examples, args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.optimizer2, self.scheduler2 = model.set_optimizer(
            self.model, data.dataloader.num_train_examples, args.train_batch_size,
            args.num_refine_epochs, args.lr, args.warmup_proportion)

        self.device = model.device

        if not args.train:
            self.model = restore_model(self.model, args.model_output_dir)

    def initialize_centroids(self, args, data):
        """Seed ``self.model.cluster_layer`` with K-Means centroids of unlabeled-train features."""
        self.logger.info("Initialize centroids...")
        feats = self.get_outputs(args, mode='train_unlabeled', get_feats=True)
        # ``n_jobs`` was removed from KMeans in scikit-learn 1.0 (passing it raises
        # TypeError). ``n_init=10`` pins the pre-1.4 default so results do not
        # depend on the installed scikit-learn version.
        km = KMeans(n_clusters=data.num_labels, n_init=10, random_state=args.seed)
        km.fit(feats)
        self.logger.info("Initialization finished...")

        cluster_layer = self.model.cluster_layer
        # cluster_centers_ is float64; assigning a float64 tensor to ``.data``
        # would switch the layer to double, so keep the layer's own dtype.
        cluster_layer.data = torch.tensor(
            km.cluster_centers_, dtype=cluster_layer.dtype, device=self.device)

    def _train_pairwise_epoch(self, dataloader, desc, u, l, **model_kwargs):
        """Run one optimizer1/scheduler1 pass over ``dataloader`` and return the mean loss.

        ``model_kwargs`` are forwarded to the model call unchanged (e.g. ``semi=True``).
        """
        tr_loss, nb_tr_steps = 0, 0
        for batch in tqdm(dataloader, desc=desc):
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            loss = self.model(input_ids, segment_ids, input_mask, label_ids,
                              u_threshold=u, l_threshold=l, mode='train', **model_kwargs)
            loss.backward()
            tr_loss += loss.item()
            nb_tr_steps += 1

            self.optimizer1.step()
            self.scheduler1.step()
            self.optimizer1.zero_grad()
        return tr_loss / nb_tr_steps

    def train(self, args, data):
        """Run pairwise-similarity learning, then hand over to :meth:`refine`.

        Each epoch does the following:
        - one pass over the labeled training loader;
        - one pass over the full training loader (``semi=True``);
        - logs the eval NMI and the fraction of eval predictions that changed;
        - moves the upper (``u``) and lower (``l``) similarity thresholds
          toward each other.

        Learning stops once ``u < l`` or after ``args.num_train_epochs`` epochs.
        """
        self.logger.info('Pairwise-similarity Learning begin...')

        u = args.u
        l = args.l
        eta = 0
        # All-zero baseline so the first epoch's delta_label compares against "no previous prediction".
        eval_pred_last = np.zeros(len(data.dataloader.eval_examples), dtype=np.int64)

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self.model.train()
            train_labeled_loss = self._train_pairwise_epoch(
                self.train_labeled_dataloader, "Iteration (labeled)", u, l)
            train_loss = self._train_pairwise_epoch(
                self.train_dataloader, "Iteration (all train)", u, l, semi=True)

            eval_true, eval_pred = self.get_outputs(args, mode='eval')
            eval_score = clustering_score(eval_true, eval_pred)['NMI']
            # Fraction of eval samples whose predicted cluster changed since last epoch.
            delta_label = np.sum(eval_pred != eval_pred_last).astype(np.float32) / eval_pred.shape[0]
            eval_pred_last = np.copy(eval_pred)

            train_results = {
                'u_threshold': round(u, 4),
                'l_threshold': round(l, 4),
                'train_labeled_loss': train_labeled_loss,
                'train_loss': train_loss,
                'delta_label': delta_label,
                'eval_score': eval_score,
            }
            self.logger.info("***** Epoch: %s: Eval results *****", str(epoch))
            for key in sorted(train_results.keys()):
                self.logger.info(" %s = %s", key, str(train_results[key]))

            # After the first epoch u and l follow a fixed schedule (independent
            # of args.u / args.l): u drops by 0.0099 and l rises by 0.00099 per epoch.
            eta += 1.1 * 0.009
            u = 0.95 - eta
            l = 0.455 + eta * 0.1
            if u < l:
                break

        self.logger.info('Pairwise-similarity Learning finished...')
        self.refine(args, data)

    def _refine_epoch(self, dataloader, p_target):
        """Run one optimizer2/scheduler2 KL fine-tuning pass and return the mean loss.

        ``dataloader`` must yield samples in the same order as the rows of
        ``p_target``; each batch is matched to the next ``batch_size`` rows.
        """
        self.model.train()
        tr_loss, nb_tr_steps, offset = 0, 0, 0
        for batch in dataloader:
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            # The second output is passed through log() below, so it is
            # presumably the soft cluster assignment q -- TODO confirm in the model.
            _, q = self.model(input_ids, segment_ids, input_mask, mode='finetune')

            batch_size = input_ids.size(0)
            target = torch.as_tensor(
                p_target[offset:offset + batch_size], dtype=q.dtype, device=self.device)
            offset += batch_size

            # 'mean' (PyTorch's current default) averages over every element,
            # not per sample like 'batchmean'; spelled out to keep the loss
            # scale the optimizer was tuned with.
            kl_loss = F.kl_div(q.log(), target, reduction='mean')
            kl_loss.backward()
            tr_loss += kl_loss.item()
            nb_tr_steps += 1

            self.optimizer2.step()
            self.scheduler2.step()
            self.optimizer2.zero_grad()
        return tr_loss / nb_tr_steps

    def refine(self, args, data):
        """Refine clusters by KL fine-tuning toward :func:`target_distribution`.

        Centroids are first seeded by :meth:`initialize_centroids`. Each epoch
        scores the current model by eval NMI, then does the following:
        - rebuilds the auxiliary target from the model's outputs on the
          training set;
        - runs one fine-tuning pass.

        Refinement stops when either condition holds:
        - NMI has not improved for more than ``args.wait_patient`` consecutive
          epochs;
        - after the first epoch, fewer than 0.1% of training predictions
          changed.

        The best-scoring weights are loaded back into ``self.model`` at the end.
        The model is saved if ``args.save_model`` is set.
        """
        self.logger.info('Cluster refining begin...')
        self.initialize_centroids(args, data)

        # The target is computed for the whole training set, then sliced batch
        # by batch during fine-tuning. That only lines up if both passes see
        # samples in the same order, so use one fixed-order loader for both
        # (self.train_dataloader may shuffle). NOTE(review): assumes the
        # training loader iterates its whole ``.dataset`` (no subset sampler).
        refine_dataloader = DataLoader(
            self.train_dataloader.dataset,
            batch_size=args.train_batch_size,
            shuffle=False,
            num_workers=self.train_dataloader.num_workers,
            collate_fn=self.train_dataloader.collate_fn,
        )

        best_state = None
        best_eval_score = 0
        wait = 0
        train_preds_last = None
        stopped_early = False

        for epoch in range(args.num_refine_epochs):
            eval_true, eval_pred = self.get_outputs(args, mode='eval')
            eval_score = clustering_score(eval_true, eval_pred)['NMI']

            if eval_score > best_eval_score:
                # Snapshot the weights rather than swapping self.model for a
                # deepcopy. optimizer2 holds the live model's parameters, so a
                # swapped-in copy would never be updated.
                best_state = copy.deepcopy(self.model.state_dict())
                best_eval_score = eval_score
                wait = 0
            else:
                wait += 1
                if wait > args.wait_patient:
                    stopped_early = True
                    break

            train_pred_logits = self.get_outputs(
                args, mode='train', get_logits=True, dataloader=refine_dataloader)
            p_target = target_distribution(train_pred_logits)
            train_preds = train_pred_logits.argmax(1)
            if train_preds_last is None:
                delta_label = 1.0  # no previous epoch: treat every prediction as changed
            else:
                delta_label = float(np.mean(train_preds != train_preds_last))
            train_preds_last = np.copy(train_preds)

            if epoch > 0 and delta_label < 0.001:
                self.logger.info('Break at epoch: %s and delta_label: %f.', str(epoch + 1), round(delta_label, 2))
                stopped_early = True
                break

            # Fine-tuning with auxiliary distribution
            train_loss = self._refine_epoch(refine_dataloader, p_target)

            eval_results = {
                'kl_loss': round(train_loss, 4),
                'delta_label': round(delta_label, 4),
                'eval_score': round(eval_score, 2),
                'best_eval_score': round(best_eval_score, 2),
            }
            self.logger.info("***** Epoch: %s: Eval results *****", str(epoch))
            for key in sorted(eval_results.keys()):
                self.logger.info(" %s = %s", key, str(eval_results[key]))

        if not stopped_early:
            # The loop ran to completion, so the last fine-tuning pass has not been scored yet.
            eval_true, eval_pred = self.get_outputs(args, mode='eval')
            eval_score = clustering_score(eval_true, eval_pred)['NMI']
            if eval_score > best_eval_score:
                best_state = copy.deepcopy(self.model.state_dict())
                best_eval_score = eval_score

        if best_state is not None:
            # Copies values in place, so the model object (and optimizer references) are kept.
            self.model.load_state_dict(best_state)

        self.logger.info('Cluster refining finished...')

        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def get_outputs(self, args, mode='eval', get_feats=False, get_logits=False, dataloader=None):
        """Run the model in eval mode over one split and collect its outputs.

        Args:
            args: Run configuration; ``args.num_labels`` sets the width of the
                empty arrays returned when the loader yields no batches.
            mode: One of ``'eval'``, ``'test'``, ``'train_unlabeled'`` or
                ``'train'``. Ignored when ``dataloader`` is given.
            get_feats: Return the first model output (``pooled_output``) as a numpy array.
            get_logits: Return the second model output as a numpy array.
            dataloader: Optional loader to use instead of the one selected by ``mode``.

        Returns:
            The features if ``get_feats``, else the logits if ``get_logits``.
            Otherwise a ``(y_true, y_pred)`` tuple, where ``y_pred`` is the
            argmax of the logits.

        Raises:
            ValueError: If ``mode`` is unknown and no ``dataloader`` is given.
        """
        if dataloader is None:
            loaders = {
                'eval': self.eval_dataloader,
                'test': self.test_dataloader,
                'train_unlabeled': self.train_unlabeled_dataloader,
                'train': self.train_dataloader,
            }
            if mode not in loaders:
                raise ValueError(f"Unknown mode: {mode!r}")
            dataloader = loaders[mode]

        self.model.eval()

        # Collect per-batch CPU chunks and concatenate once (avoids quadratic
        # re-concatenation and keeps accumulated outputs off the device).
        label_chunks, feat_chunks, logit_chunks = [], [], []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Iteration"):
                input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
                pooled_output, logits = self.model(input_ids, segment_ids, input_mask)
                label_chunks.append(label_ids.cpu())
                feat_chunks.append(pooled_output.cpu())
                logit_chunks.append(logits.cpu())

        if get_feats:
            return _cat_or_empty(feat_chunks, (0, args.num_labels), torch.float32).numpy()

        total_logits = _cat_or_empty(logit_chunks, (0, args.num_labels), torch.float32)
        if get_logits:
            return total_logits.numpy()

        y_pred = total_logits.argmax(1).numpy()
        y_true = _cat_or_empty(label_chunks, (0,), torch.long).numpy()
        return y_true, y_pred

    def test(self, args, data):
        """Evaluate on the test split, log the confusion matrix and scores, and return them.

        The returned dict is the ``clustering_score`` result plus ``'y_true'``
        and ``'y_pred'`` arrays.
        """
        y_true, y_pred = self.get_outputs(args, mode='test')
        test_results = clustering_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)

        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")
        for key in sorted(test_results.keys()):
            self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results