| import logging | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from utils.metrics import clustering_score | |
| from sklearn.metrics import confusion_matrix | |
| from tqdm import trange, tqdm | |
| from sklearn.cluster import KMeans | |
| from torch.utils.data import (DataLoader, RandomSampler, TensorDataset) | |
| from utils.functions import save_model | |
| from losses import contrastive_loss | |
class CCmanager:
    """Train and evaluate a model with instance- and cluster-level contrastive losses.

    Training pairs every training example with an identical copy of itself and
    minimises ``InstanceLoss + ClusterLoss`` over the two resulting views.
    Evaluation runs KMeans with ``num_labels`` clusters on features extracted
    from the test split and scores the clustering against the gold labels.
    """

    def __init__(self, args, data, model, logger_name='Discovery'):
        """Wire up data loaders, network, optimiser/scheduler and loss functions.

        Args:
            args: run configuration. This class reads ``train_batch_size``,
                ``freeze_bert_parameters``, ``num_train_epochs``, ``lr``,
                ``warmup_proportion``, ``seed``, ``feat_dim``, ``save_model``
                and ``model_output_dir``. Optional ``instance_temperature``
                (default 0.7) and ``cluster_temperature`` (default 1.0)
                override the contrastive-loss temperatures.
            data: dataset wrapper exposing ``num_labels`` and ``dataloader``;
                the latter's ``train_outputs`` / ``test_outputs`` dicts supply
                the loaders and the raw training tensors.
            model: model manager exposing ``device``, ``set_model`` and
                ``set_optimizer``.
            logger_name: name passed to ``logging.getLogger``.
        """
        self.logger = logging.getLogger(logger_name)
        self.device = model.device
        self.num_labels = data.num_labels

        loader = data.dataloader
        self.train_dataloader = loader.train_outputs['loader']
        self.test_dataloader = loader.test_outputs['loader']
        self.train_input_ids = loader.train_outputs['input_ids']
        self.train_input_mask = loader.train_outputs['input_mask']
        self.train_segment_ids = loader.train_outputs['segment_ids']

        self.augdataloader = self.get_augment_dataloader(args)
        self.set_model_optimizer(args, data, model)

        # Formerly hard-coded; the defaults keep the original values.
        self.instance_temperature = getattr(args, 'instance_temperature', 0.7)
        self.cluster_temperature = getattr(args, 'cluster_temperature', 1.0)
        self.criterion_instance = contrastive_loss.InstanceLoss(
            args.train_batch_size, self.instance_temperature, self.device)
        self.criterion_cluster = contrastive_loss.ClusterLoss(
            self.num_labels, self.cluster_temperature, self.device)

    def set_model_optimizer(self, args, data, model):
        """Build the network and its optimiser/scheduler through the ``model`` manager."""
        self.model = model.set_model(args, data, 'bert', args.freeze_bert_parameters)
        self.optimizer, self.scheduler = model.set_optimizer(
            self.model, data.dataloader.num_train_examples, args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.device = model.device

    def batch_chunk(self, x):
        """Split a stacked pair along dim 1 into its two views.

        ``x`` is expected to have size 2 along dim 1, as produced by
        ``get_augment_dataloader``. Each returned tensor has that dim removed.
        """
        x1, x2 = torch.chunk(input=x, chunks=2, dim=1)
        return x1.squeeze(1), x2.squeeze(1)

    def train(self, args, data):
        """Optimise the instance + cluster contrastive loss for ``num_train_epochs`` epochs.

        Saves the model to ``args.model_output_dir`` afterwards when
        ``args.save_model`` is truthy.

        Raises:
            ValueError: if the augmented loader yields no batches. It drops
                incomplete batches, so this happens when there are fewer
                training examples than ``train_batch_size``.
        """
        self.logger.info('CC training starts...')
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss, nb_tr_steps = 0, 0
            self.model.train()
            for batch in tqdm(self.augdataloader, desc="Training(All)"):
                input_ids, input_mask, segment_ids = (t.to(self.device) for t in batch)
                with torch.set_grad_enabled(True):
                    input_ids_a, input_ids_b = self.batch_chunk(input_ids)
                    input_mask_a, input_mask_b = self.batch_chunk(input_mask)
                    segment_ids_a, segment_ids_b = self.batch_chunk(segment_ids)

                    # Both views carry identical tokens, so any difference
                    # between x_i and x_j comes from stochastic layers that are
                    # active in train mode (e.g. dropout).
                    x_i = self.model(input_ids_a, segment_ids_a, input_mask_a)
                    x_j = self.model(input_ids_b, segment_ids_b, input_mask_b)
                    z_i, z_j, c_i, c_j = self.model.get_features(x_i, x_j)

                    loss_instance = self.criterion_instance(z_i, z_j)
                    loss_cluster = self.criterion_cluster(c_i, c_j)
                    loss = loss_instance + loss_cluster

                    self.optimizer.zero_grad()
                    loss.backward()
                    tr_loss += loss.item()
                    nb_tr_steps += 1
                    self.optimizer.step()
                    self.scheduler.step()

            if nb_tr_steps == 0:
                raise ValueError(
                    "No training batches were produced: the augmented dataloader "
                    "drops incomplete batches, so at least train_batch_size "
                    f"({args.train_batch_size}) training examples are required.")
            train_loss = tr_loss / nb_tr_steps
            self.logger.info("***** Epoch: %s: train results *****", str(epoch))
            self.logger.info(" train_loss = %s", str(train_loss))

        self.logger.info('CC training finished...')
        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def test(self, args, data):
        """Cluster test-split features with KMeans and score them against gold labels.

        Returns:
            dict: the metrics from ``clustering_score``, plus ``'y_true'`` and
            ``'y_pred'`` label arrays.
        """
        feats, y_true = self.get_outputs(args, mode='test')
        km = KMeans(n_clusters=self.num_labels, random_state=args.seed).fit(feats)
        y_pred = km.labels_

        test_results = clustering_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)

        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")
        for key in sorted(test_results.keys()):
            self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results

    def get_outputs(self, args, mode):
        """Run the model in eval mode over a split and collect features and labels.

        Args:
            args: must provide ``feat_dim``, used for the shape of an empty result.
            mode: split to evaluate. Only ``'test'`` is supported.

        Returns:
            tuple: ``(feats, y_true)`` as numpy arrays, with one feature row and
            one label per example.

        Raises:
            ValueError: for any ``mode`` other than ``'test'``.
        """
        if mode != 'test':
            raise ValueError(f"Unsupported mode {mode!r}; only 'test' is supported.")
        dataloader = self.test_dataloader

        self.model.eval()
        # Collect per-batch tensors and concatenate once, rather than
        # re-copying the growing result on every batch.
        feature_batches, label_batches = [], []
        for batch in tqdm(dataloader, desc="Iteration"):
            input_ids, input_mask, segment_ids, label_ids = (t.to(self.device) for t in batch)
            with torch.set_grad_enabled(False):
                pooled_output = self.model(input_ids, segment_ids, input_mask)
            label_batches.append(label_ids)
            feature_batches.append(pooled_output)

        if feature_batches:
            total_features = torch.cat(feature_batches)
            total_labels = torch.cat(label_batches)
        else:
            # Empty split: keep the same shapes/dtypes the original produced.
            total_features = torch.empty((0, args.feat_dim))
            total_labels = torch.empty(0, dtype=torch.long)

        feats = total_features.cpu().numpy()
        y_true = total_labels.cpu().numpy()
        return feats, y_true

    def get_augment_dataloader(self, args):
        """Build a shuffled loader yielding each training example paired with a copy of itself.

        Each training tensor gains a new dim 1 of size 2 that holds two
        identical copies; ``batch_chunk`` splits it back into the two views.
        ``drop_last=True`` keeps every batch at exactly ``train_batch_size``,
        presumably because ``InstanceLoss`` is constructed for that fixed size
        (TODO confirm).
        """
        paired = [
            torch.stack((t, t), dim=1)
            for t in (self.train_input_ids, self.train_input_mask, self.train_segment_ids)
        ]
        train_data = TensorDataset(*paired)
        train_sampler = RandomSampler(train_data)
        return DataLoader(train_data, sampler=train_sampler,
                          batch_size=args.train_batch_size, drop_last=True)