import torch import torch.nn.functional as F import logging import os import torch.nn as nn import numpy as np import copy from sklearn.cluster import KMeans from sklearn.metrics import confusion_matrix from tqdm import trange, tqdm from losses import loss_map from utils.functions import save_model, restore_model, MemoryBank, fill_memory_bank, view_generator, set_seed from utils.neighbor_dataset import NeighborsDataset from torch.utils.data import DataLoader from .pretrain import PretrainMTP_CLNNManager from utils.metrics import clustering_score from transformers import AutoTokenizer class MTP_CLNNManager: def __init__(self, args, data, model, logger_name = 'Discovery'): self.logger = logging.getLogger(logger_name) pretrain_manager = PretrainMTP_CLNNManager(args, data, model) set_seed(args.seed) self.logger = logging.getLogger(logger_name) loader = data.dataloader self.train_dataloader, self.eval_dataloader, self.test_dataloader = \ loader.train_outputs['loader'], loader.eval_outputs['loader'], loader.test_outputs['loader'] self.train_dataset = loader.train_outputs['semi_data'] self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_bert_model) self.generator = view_generator(self.tokenizer, args) self.temp=0.07 if args.pretrain: self.pretrained_model = pretrain_manager.model self.set_model_optimizer(args, data, model) self.num_labels = data.num_labels self.load_pretrained_model(self.pretrained_model) else: self.pretrained_model = restore_model(pretrain_manager.model, os.path.join(args.method_output_dir, 'pretrain')) self.set_model_optimizer(args, data, model) self.num_labels = data.num_labels self.model = restore_model(self.model, args.model_output_dir) topk = {'banking': 50, 'clinc': 60, 'stackoverflow': 300} if args.cluster_num_factor > 1: self.logger.info('num_labels is %s, Length of train_dataset is %s', str(self.num_labels), str(len(self.train_dataset))) args.topk = int((len(self.train_dataset) * 0.5) / self.num_labels) else: args.topk = topk[args.dataset] self.logger.info('Topk for %s is %s', str(args.dataset), str(args.topk)) def set_model_optimizer(self, args, data, model): if args.dataset == 'stackoverflow': args.lr = 1e-6 args.backbone = 'bert_MTP' self.model = model.set_model(args, data, 'bert', args.freeze_train_bert_parameters) self.optimizer , self.scheduler = model.set_optimizer(self.model, len(data.dataloader.train_examples), args.train_batch_size, \ args.num_train_epochs, args.lr, args.warmup_proportion) self.device = model.device self.criterion = self.model.loss_cl def train(self, args, data): indices = self.get_neighbor_inds(args, data) self.get_neighbor_dataset(args, data, indices) best_eval_score = 0 for epoch in trange(int(args.num_train_epochs), desc="Epoch"): self.model.train() tr_loss = 0 nb_tr_examples, nb_tr_steps = 0, 0 for batch in tqdm(self.train_dataloader_neighbor, desc="Iteration"): anchor = tuple(t.to(self.device) for t in batch["anchor"]) neighbor = tuple(t.to(self.device) for t in batch["neighbor"]) pos_neighbors = batch["possible_neighbors"] data_inds = batch["index"] adjacency = self.get_adjacency(args, data_inds, pos_neighbors, batch["target"]) # (bz,bz) X_an = {"input_ids":self.generator.random_token_replace(anchor[0].cpu()).to(self.device), "attention_mask":anchor[1], "token_type_ids":anchor[2]} X_ng = {"input_ids":self.generator.random_token_replace(neighbor[0].cpu()).to(self.device), "attention_mask":neighbor[1], "token_type_ids":neighbor[2]} with torch.set_grad_enabled(True): f_pos = torch.stack([self.model(X_an)["features"], self.model(X_ng)["features"]], dim=1) loss = self.criterion(f_pos, mask=adjacency, temperature=self.temp, device = self.device) tr_loss += loss.item() loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), args.grad_clip) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() nb_tr_examples += anchor[0].size(0) nb_tr_steps += 1 loss = tr_loss / nb_tr_steps self.logger.info("***** Epoch: %s *****", str(epoch)) self.logger.info('Training Loss: %f', np.round(loss, 5)) if ((epoch + 1) % args.update_per_epoch) == 0: self.logger.info("Update neighbors...") indices = self.get_neighbor_inds(args, data) self.get_neighbor_dataset(args, data, indices) if args.save_model: save_model(self.model, args.model_output_dir) def test(self, args, data): feats, y_true = self.get_outputs(args, mode = 'test', model = self.model, get_feats = True) km = KMeans(n_clusters = self.num_labels, random_state=args.seed).fit(feats) y_pred = km.labels_ test_results = clustering_score(y_true, y_pred) cm = confusion_matrix(y_true, y_pred) self.logger.info self.logger.info("***** Test: Confusion Matrix *****") self.logger.info("%s", str(cm)) self.logger.info("***** Test results *****") for key in sorted(test_results.keys()): self.logger.info(" %s = %s", key, str(test_results[key])) test_results['y_true'] = y_true test_results['y_pred'] = y_pred return test_results def get_outputs(self, args, mode, model, get_feats = False): if mode == 'eval': dataloader = self.eval_dataloader elif mode == 'test': dataloader = self.test_dataloader elif mode == 'train': dataloader = self.train_dataloader model.eval() total_labels = torch.empty(0,dtype=torch.long).to(self.device) total_features = torch.empty((0,args.feat_dim)).to(self.device) for batch in tqdm(dataloader, desc="Iteration"): batch = tuple(t.to(self.device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch X = {"input_ids":input_ids, "attention_mask": input_mask, "token_type_ids": segment_ids} with torch.set_grad_enabled(False): pooled_output = model(X)["hidden_states"] total_labels = torch.cat((total_labels,label_ids)) total_features = torch.cat((total_features, pooled_output)) if get_feats: feats = total_features.cpu().numpy() y_true = total_labels.cpu().numpy() return feats, y_true def load_pretrained_model(self, pretrained_model): pretrained_dict = pretrained_model.state_dict() self.model.load_state_dict(pretrained_dict, strict=False) def get_neighbor_dataset(self, args, data, indices): """convert indices to dataset""" dataset = NeighborsDataset(self.train_dataset, indices) self.train_dataloader_neighbor = DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) def get_neighbor_inds(self, args, data): """get indices of neighbors""" memory_bank = MemoryBank(len(self.train_dataset), args.feat_dim, self.num_labels, 0.1) fill_memory_bank(self, self.train_dataloader, self.model, memory_bank) indices = memory_bank.mine_nearest_neighbors(args.topk, args.gpu_id ,calculate_accuracy=False) return indices def get_adjacency(self, args, inds, neighbors, targets): """get adjacency matrix""" adj = torch.zeros(inds.shape[0], inds.shape[0]) for b1, n in enumerate(neighbors): adj[b1][b1] = 1 for b2, j in enumerate(inds): if j in n: adj[b1][b2] = 1 if (targets[b1] == targets[b2]) and (targets[b1]>0) and (targets[b2]>0): adj[b1][b2] = 1 return adj