# Provenance (residue from the hosting page): uploaded by THU-IAR,
# commit 2d06dcc (verified), message "Upload 198 files".
import logging
import numpy as np
import torch
import torch.nn as nn
from utils.metrics import clustering_score
from sklearn.metrics import confusion_matrix
from tqdm import trange, tqdm
from sklearn.cluster import KMeans
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)
from utils.functions import save_model
from losses import contrastive_loss
class CCmanager:
    """Train and evaluate a CC model for unsupervised clustering of text.

    "CC" is presumably Contrastive Clustering (the losses come from
    ``contrastive_loss.InstanceLoss`` and ``contrastive_loss.ClusterLoss``);
    confirm against the model definition. Training minimizes the sum of an
    instance-level and a cluster-level contrastive loss computed on two views
    of every training example. Evaluation runs KMeans on the model's test-set
    features and reports ``clustering_score`` metrics.
    """

    def __init__(self, args, data, model, logger_name = 'Discovery'):
        """Wire up loaders, model, optimizer/scheduler and loss criteria.

        Args:
            args: run configuration. This class reads ``train_batch_size``,
                ``freeze_bert_parameters``, ``num_train_epochs``, ``lr``,
                ``warmup_proportion``, ``feat_dim``, ``seed``, ``save_model``
                and ``model_output_dir`` from it.
            data: dataset wrapper exposing ``num_labels`` and ``dataloader``.
                ``dataloader`` must provide ``train_outputs`` / ``test_outputs``
                dicts (keys ``'loader'``, ``'input_ids'``, ``'input_mask'``,
                ``'segment_ids'``) and ``num_train_examples``.
            model: model manager providing ``device``, ``set_model`` and
                ``set_optimizer``.
            logger_name: name of the ``logging`` logger to write to.
        """
        self.logger = logging.getLogger(logger_name)
        self.device = model.device
        self.num_labels = data.num_labels

        loader = data.dataloader
        self.train_dataloader, self.test_dataloader = \
            loader.train_outputs['loader'], loader.test_outputs['loader']
        self.train_input_ids, self.train_input_mask, self.train_segment_ids = \
            loader.train_outputs['input_ids'], loader.train_outputs['input_mask'], loader.train_outputs['segment_ids']

        self.augdataloader = self.get_augment_dataloader(args)
        self.set_model_optimizer(args, data, model)

        # Loss temperatures are fixed here, not read from args.
        self.instance_temperature = 0.7
        self.cluster_temperature = 1.0
        # InstanceLoss takes train_batch_size. The augmented loader uses
        # drop_last=True, so every training batch has exactly that many rows
        # (presumably InstanceLoss relies on a fixed batch size; confirm).
        self.criterion_instance = contrastive_loss.InstanceLoss(args.train_batch_size, self.instance_temperature, self.device)
        self.criterion_cluster = contrastive_loss.ClusterLoss(self.num_labels, self.cluster_temperature, self.device)

    def set_model_optimizer(self, args, data, model):
        """Build the 'bert' model plus its optimizer and LR scheduler via ``model``."""
        self.model = model.set_model(args, data, 'bert', args.freeze_bert_parameters)
        self.optimizer , self.scheduler = model.set_optimizer(self.model, data.dataloader.num_train_examples, args.train_batch_size, \
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.device = model.device

    def batch_chunk(self, x):
        """Split a paired tensor of shape (B, 2, ...) into its two views, each (B, ...)."""
        x1, x2 = torch.chunk(input=x, chunks=2, dim=1)
        x1, x2 = x1.squeeze(1), x2.squeeze(1)
        return x1, x2

    def train(self, args, data):
        """Optimize instance + cluster contrastive loss for ``args.num_train_epochs`` epochs.

        The average training loss is logged once per epoch. If
        ``args.save_model`` is set, the model is saved to
        ``args.model_output_dir`` afterwards. ``data`` is not used.

        Raises:
            ValueError: if the augmented training loader yields no batches.
                It drops the last partial batch, so this happens when there
                are fewer training examples than ``args.train_batch_size``.
        """
        if len(self.augdataloader) == 0:
            # Without this check the per-epoch average below would divide by
            # zero after an epoch that silently trained on nothing.
            raise ValueError(
                'CC training has no batches: fewer training examples than '
                'train_batch_size (%s), and the last partial batch is dropped.'
                % args.train_batch_size)

        self.logger.info('CC training starts...')

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss, nb_tr_steps = 0, 0
            self.model.train()

            for batch in tqdm(self.augdataloader, desc="Training(All)"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids = batch

                with torch.set_grad_enabled(True):
                    # Each tensor carries both views along dim 1; split them.
                    input_ids_a, input_ids_b = self.batch_chunk(input_ids)
                    input_mask_a, input_mask_b = self.batch_chunk(input_mask)
                    segment_ids_a, segment_ids_b = self.batch_chunk(segment_ids)

                    x_i = self.model(input_ids_a, segment_ids_a, input_mask_a)
                    x_j = self.model(input_ids_b, segment_ids_b, input_mask_b)

                    z_i, z_j, c_i, c_j = self.model.get_features(x_i, x_j)
                    loss_instance = self.criterion_instance(z_i, z_j)
                    loss_cluster = self.criterion_cluster(c_i, c_j)
                    loss = loss_instance + loss_cluster

                    self.optimizer.zero_grad()
                    loss.backward()

                    tr_loss += loss.item()
                    nb_tr_steps += 1

                    self.optimizer.step()
                    self.scheduler.step()

            train_loss = tr_loss / nb_tr_steps
            self.logger.info("***** Epoch: %s: train results *****", str(epoch))
            self.logger.info(" train_loss = %s", str(train_loss))

        self.logger.info('CC training finished...')

        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def test(self, args, data):
        """Cluster test-set features with KMeans and log/return clustering metrics.

        Returns:
            The dict from ``clustering_score(y_true, y_pred)``, extended with
            the ``'y_true'`` and ``'y_pred'`` label arrays.
        """
        feats, y_true = self.get_outputs(args, mode = 'test')
        km = KMeans(n_clusters = self.num_labels, random_state=args.seed).fit(feats)
        y_pred = km.labels_

        test_results = clustering_score(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred)

        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")
        for key in sorted(test_results.keys()):
            self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results

    def get_outputs(self, args, mode):
        """Run the model in eval mode over a split and collect features and labels.

        Args:
            args: needs ``feat_dim``. It fixes the feature width of the empty
                result when the loader yields no batches.
            mode: which split to run. Only ``'test'`` is supported.

        Returns:
            ``(feats, y_true)`` as NumPy arrays: float features stacked along
            dim 0, and integer labels.

        Raises:
            ValueError: if ``mode`` is not ``'test'``.
        """
        if mode == 'test':
            dataloader = self.test_dataloader
        else:
            # Previously any other mode crashed with UnboundLocalError.
            raise ValueError("Unsupported mode %r; only 'test' is supported." % (mode,))

        self.model.eval()
        # Collect per-batch outputs and concatenate once. The old repeated
        # torch.cat inside the loop was quadratic in the number of batches.
        all_labels, all_features = [], []
        for batch in tqdm(dataloader, desc="Iteration"):
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.no_grad():
                pooled_output = self.model(input_ids, segment_ids, input_mask)
            all_labels.append(label_ids)
            all_features.append(pooled_output)

        if not all_features:
            # Same empty result the original's pre-allocated buffers produced.
            return np.empty((0, args.feat_dim), dtype=np.float32), np.empty(0, dtype=np.int64)

        # .float()/.long() keep the dtypes the original buffers produced.
        feats = torch.cat(all_features).float().cpu().numpy()
        y_true = torch.cat(all_labels).long().cpu().numpy()
        return feats, y_true

    def get_augment_dataloader(self, args):
        """Build a shuffled training loader that yields two copies of each example.

        Each stored training tensor of shape (N, ...) becomes (N, 2, ...) by
        stacking it with itself along dim 1. ``batch_chunk`` later splits the
        pair back into two views. Both views hold identical tokens, so any
        difference between them must come from the model itself (presumably
        dropout; TODO confirm). ``drop_last=True`` keeps every batch at exactly
        ``args.train_batch_size`` rows.
        """
        def paired(tensor):
            # Equivalent to cat([t.unsqueeze(1), t.unsqueeze(1)], dim=1).
            return torch.stack([tensor, tensor], dim=1)

        train_data = TensorDataset(
            paired(self.train_input_ids),
            paired(self.train_input_mask),
            paired(self.train_segment_ids),
        )
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size = args.train_batch_size, drop_last=True)
        return train_dataloader