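"""Pretraining manager for the DTC method.

Pretrains a BERT-based classifier on the known (labeled) classes and, when
args.cluster_num_factor > 1.0, estimates the total number of clusters K with a
semi-supervised k-means procedure.
"""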
import logging
import os
import copy
import random
import math

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

from sklearn.metrics import silhouette_score, accuracy_score
from scipy.optimize import linear_sum_assignment
from tqdm import trange, tqdm
from collections import Counter

from losses import loss_map
from utils.functions import save_model, restore_model
from utils.faster_mix_k_means_pytorch import K_Means
from utils.metrics import clustering_accuracy_score
class PretrainDTCManager:
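    """Manager that pretrains the DTC model on labeled data and estimates K.

    The constructor trains (or restores) a supervised classifier on the known
    classes; if args.cluster_num_factor > 1.0 it additionally runs predict_k to
    estimate the number of clusters.

    Assumed usage inside the surrounding discovery pipeline (the exact call site
    is not part of this file):

        manager = PretrainDTCManager(args, data, model)
        estimated_k = manager.num_labels
    """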
def __init__(self, args, data, model, logger_name = 'Discovery'):
self.logger = logging.getLogger(logger_name)
loader = data.dataloader
self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
loader.train_labeled_outputs['loader'], loader.eval_outputs['loader'], loader.test_outputs['loader']
args.num_labels = data.n_known_cls
self.set_model_optimizer(args, data, model)
self.loss_fct = loss_map[args.loss_fct]
if args.pretrain:
            self.logger.info('Pre-training start...')
self.train(args, data)
self.logger.info('Pre-training finished...')
else:
self.model = restore_model(self.model, os.path.join(args.method_output_dir, 'pretrain'))
        if args.cluster_num_factor > 1.0:
            # Use the inflated cluster number only as an upper bound; the actual K
            # is estimated below.
            self.num_labels = data.num_labels
            self.num_labels = self.predict_k(args, data, model)
def set_model_optimizer(self, args, data, model):
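        """Build the backbone model plus its optimizer and scheduler for pretraining."""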
self.model = model.set_model(args, data, 'bert', args.freeze_bert_parameters)
self.optimizer, self.scheduler = model.set_optimizer(self.model, len(data.dataloader.train_labeled_examples), args.train_batch_size, \
args.num_pretrain_epochs, args.lr_pre, args.warmup_proportion)
self.device = model.device
def predict_k(self, args, data, model):
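        """Estimate the number of clusters K.

        As implemented below:
        1. Fine-tune a fresh classifier on the DTC labeled split, with early
           stopping on clustering accuracy over the eval set.
        2. Extract features for the unlabeled training data and for the held-out
           labeled validation split.
        3. Split the validation classes 75/25 into anchor classes (k-means
           constraints) and probe classes (used to score candidate K values).
        4. For each candidate K, run constrained k-means, record the silhouette
           score and the probe clustering accuracy, and pick K with get_best_k;
           stop once the estimate has been stable for several candidates.
        """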
loader = data.dataloader
self.dtc_train_labeled_dataloader, self.dtc_train_unlabeled_dataloader, self.dtc_eval_dataloader, self.dtc_val_labeled_dataloader = \
loader.train_labeled_outputs_dtc['loader'], loader.train_unlabeled_outputs_dtc['loader'], loader.eval_outputs_dtc['loader'], loader.val_labeled_outputs_dtc['loader']
self.predict_model = model.set_model(args, data, 'bert')
self.predict_optimizer, self.predict_scheduler = model.set_optimizer(self.predict_model, len(data.dataloader.train_labeled_examples_dtc), args.train_batch_size, \
args.num_pretrain_epochs, args.lr_pre, args.warmup_proportion)
self.logger.info("***** Running predict k *****")
self.predict_model.to(self.device)
        # Early-stopping state: keep the checkpoint with the best clustering
        # accuracy on the eval set and stop after `patience` epochs without improvement.
        wait = 0
        best_model = None
        patience = 1
        acc_best = 0
for epoch in trange(int(args.num_pretrain_epochs), desc="Epoch"):
self.predict_model.train()
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(self.dtc_train_labeled_dataloader, desc="Iteration (labeled)")):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
with torch.set_grad_enabled(True):
loss = self.predict_model(input_ids, segment_ids, input_mask, label_ids, loss_fct = self.loss_fct, mode = "train")
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
self.predict_optimizer.step()
self.predict_scheduler.step()
self.predict_optimizer.zero_grad()
loss = tr_loss / nb_tr_steps
self.logger.info("loss: %s", str(loss))
self.predict_model.eval()
total_logits = torch.empty((0, data.n_known_cls)).to(self.device)
total_labels = torch.empty(0,dtype=torch.long).to(self.device)
for batch in tqdm(self.dtc_eval_dataloader, desc="Extracting representation"):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
with torch.no_grad():
logits, _ = self.predict_model(input_ids, segment_ids, input_mask)
total_logits = torch.cat((total_logits, logits))
total_labels = torch.cat((total_labels, label_ids))
probs, preds = F.softmax(total_logits, dim = 1).max(dim = 1)
y_pred = preds.cpu().numpy()
y_true = total_labels.cpu().numpy()
acc = clustering_accuracy_score(y_true, y_pred)
self.logger.info("eval_results: %s", str(acc))
            if acc > acc_best:
                best_model = copy.deepcopy(self.predict_model)
                wait = 0
                acc_best = acc
            else:
                wait += 1
                if wait >= patience:
                    # Restore the best checkpoint (if any) before stopping early.
                    if best_model is not None:
                        self.predict_model = best_model
                    break
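        # With the pretrained classifier frozen, extract features for the unlabeled
        # training split; these are clustered for each candidate K below.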
max_cand_k = self.num_labels
self.predict_model.eval()
u_labels = torch.empty(0, dtype=torch.long).to(self.device)
u_features = torch.empty((0, args.feat_dim)).to(self.device)
with torch.set_grad_enabled(False):
for batch in tqdm(self.dtc_train_unlabeled_dataloader, desc="Extracting Features"):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
features = self.predict_model(input_ids, segment_ids, input_mask, feature_ext = True)
u_features = torch.cat((u_features, features))
u_labels = torch.cat((u_labels, label_ids))
u_feats = u_features.cpu().numpy()
u_labels = u_labels.cpu().numpy()
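        # Extract features for the held-out labeled validation split; its classes
        # are divided below into anchor (constraint) and probe (scoring) subsets.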
self.predict_model.eval()
l_labels = torch.empty(0, dtype=torch.long).to(self.device)
l_features = torch.empty((0, args.feat_dim)).to(self.device)
with torch.set_grad_enabled(False):
for batch in tqdm(self.dtc_val_labeled_dataloader, desc="Extracting Features"):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
features = self.predict_model(input_ids, segment_ids, input_mask, feature_ext = True)
l_features = torch.cat((l_features, features))
l_labels = torch.cat((l_labels, label_ids))
l_feats = l_features.cpu().numpy()
l_targets = l_labels.cpu().numpy()
l_classes = set(l_targets)
        # Randomly split the labeled validation classes: 75% become anchor classes
        # (used as k-means constraints), the rest are probes for scoring candidate K.
        split_ratio = 0.75
        num_lt_cls = int(round(len(l_classes) * split_ratio))
        # random.sample needs a sequence (sampling directly from a set was removed in Python 3.11).
        lt_classes = set(random.sample(sorted(l_classes), num_lt_cls))
        lv_classes = l_classes - lt_classes
lt_feats = np.empty((0, l_feats.shape[1]))
lt_targets = np.empty(0)
for c in lt_classes:
lt_feats = np.vstack((lt_feats, l_feats[l_targets==c]))
lt_targets = np.append(lt_targets, l_targets[l_targets==c])
lv_feats = np.empty((0, l_feats.shape[1]))
lv_targets = np.empty(0)
for c in lv_classes:
lv_feats = np.vstack((lv_feats, l_feats[l_targets==c]))
lv_targets = np.append(lv_targets, l_targets[l_targets==c])
cand_k = np.arange(max_cand_k)
cvi_list = np.zeros(len(cand_k))
acc_list = np.zeros(len(cand_k))
u_num = len(u_labels)
l_num = len(l_targets)
cat_pred_list = np.zeros([len(cand_k),u_num+l_num])
self.logger.info("estimating K ...")
from sklearn.metrics import silhouette_score
num_k = 10
cnt = 0
last_k = -1
num_val_cls = data.dataloader.num_val_cls
        lv_targets = lv_targets.astype(int)
        for i in range(len(cand_k)):
            # Constrained k-means: anchor classes are labeled constraints, probe and
            # unlabeled features are clustered freely.
            cvi_list[i], cat_pred_i = self.labeled_val_fun(np.concatenate((lv_feats, u_feats)), lt_feats, lt_targets, cand_k[i] + num_val_cls)
            cat_pred_list[i, :] = cat_pred_i
            # Score this candidate K by clustering accuracy on the probe classes.
            lv_pred_i = cat_pred_i[len(lt_targets): len(lt_targets) + len(lv_targets)]
            self.logger.debug("probe predictions: %s", str(lv_pred_i))
            self.logger.debug("probe targets: %s", str(lv_targets))
            acc_list[i] = clustering_accuracy_score(lv_targets, lv_pred_i)
best_k = self.get_best_k(cvi_list[:i+1], acc_list[:i+1], cat_pred_list[:i+1], l_num) + data.n_known_cls
if best_k == last_k:
cnt += 1
if cnt >= num_k:
break
else:
last_k = best_k
cnt=0
self.logger.info("current best K: %s", str(best_k))
self.logger.info("best K: %s", str(best_k))
return best_k
def get_best_k(self, cvi_list, acc_list, cat_pred_list, l_num):
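        """Pick the best K from the candidates evaluated so far.

        The selected candidate index is the midpoint between the last index that
        maximizes the silhouette score and the last index that maximizes probe
        accuracy. The returned K is the number of clusters whose size exceeds
        min_max_ratio of the largest cluster among the purely unlabeled
        predictions, which prunes tiny clusters (the caller adds the number of
        known classes on top). For example, with unlabeled cluster sizes
        [50, 40, 3, 1] and min_max_ratio = 0.1 the size ratios are
        [1.0, 0.8, 0.06, 0.02], so best_k = 2.
        """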
min_max_ratio = 0.1
idx_cvi = np.max(np.argwhere(cvi_list==np.max(cvi_list)))
idx_acc = np.max(np.argwhere(acc_list==np.max(acc_list)))
idx_best = int(math.ceil((idx_cvi+idx_acc)*1.0/2))
cat_pred = cat_pred_list[idx_best, :]
cnt_cat = Counter(cat_pred.tolist())
cnt_l = Counter(cat_pred[:l_num].tolist())
cnt_ul = Counter(cat_pred[l_num:].tolist())
bin_cat = [x[1] for x in sorted(cnt_cat.items())]
bin_l = [x[1] for x in sorted(cnt_l.items())]
bin_ul = [x[1] for x in sorted(cnt_ul.items())]
best_k = np.sum(np.array(bin_ul)/np.max(bin_ul).astype(float)>min_max_ratio)
return best_k
def labeled_val_fun(self, u_feats, l_feats, l_targets, k):
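        """Run constrained k-means and score the resulting unlabeled partition.

        l_feats / l_targets act as labeled constraints, while u_feats (which here
        also contains the probe-class features) is clustered freely. Returns the
        silhouette score of the unlabeled predictions together with the full
        assignment (labeled points first, then unlabeled).
        """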
        # Free cached GPU memory before clustering; self.device may be a
        # torch.device or a plain string, so compare its string form.
        if torch.cuda.is_available() and str(self.device).startswith('cuda'):
            torch.cuda.empty_cache()
l_num=len(l_targets)
kmeans = K_Means(k, pairwise_batch_size=256)
kmeans.fit_mix(torch.from_numpy(u_feats).to(self.device), torch.from_numpy(l_feats).to(self.device), torch.from_numpy(l_targets).to(self.device))
cat_pred = kmeans.labels_.cpu().numpy()
u_pred = cat_pred[l_num:]
silh_score = silhouette_score(u_feats, u_pred)
return silh_score, cat_pred
def train(self, args, data):
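        """Supervised pretraining on the labeled data with early stopping.

        After each epoch, accuracy on the eval set is computed via get_outputs;
        the best checkpoint is kept and training stops after args.wait_patient
        epochs without improvement.
        """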
wait = 0
best_model = None
best_eval_score = 0
for epoch in trange(int(args.num_pretrain_epochs), desc="Epoch"):
self.model.train()
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration (labeled)")):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
with torch.set_grad_enabled(True):
loss = self.model(input_ids, segment_ids, input_mask, label_ids, loss_fct = self.loss_fct, mode = "train")
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
loss = tr_loss / nb_tr_steps
eval_true, eval_pred = self.get_outputs(args, mode = 'eval')
eval_score = accuracy_score(eval_true, eval_pred)
eval_results = {
'train_loss': loss,
'eval_score': eval_score,
'best_score':best_eval_score,
}
self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1))
for key in sorted(eval_results.keys()):
self.logger.info(" %s = %s", key, str(eval_results[key]))
if eval_score > best_eval_score:
best_model = copy.deepcopy(self.model)
wait = 0
best_eval_score = eval_score
elif eval_score > 0:
wait += 1
if wait >= args.wait_patient:
break
self.model = best_model
if args.save_model:
pretrained_model_dir = os.path.join(args.method_output_dir, 'pretrain')
if not os.path.exists(pretrained_model_dir):
os.makedirs(pretrained_model_dir)
save_model(self.model, pretrained_model_dir)
def get_outputs(self, args, mode = 'eval', get_feats = False):
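        """Run inference on the selected split and collect labels and predictions.

        With get_feats=True the collected logits are returned as features;
        otherwise returns (y_true, y_pred), where predictions are the argmax of
        the softmax over the logits.
        """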
        if mode == 'eval':
            dataloader = self.eval_dataloader
        elif mode == 'test':
            dataloader = self.test_dataloader
self.model.eval()
total_labels = torch.empty(0,dtype=torch.long).to(self.device)
total_preds = torch.empty(0,dtype=torch.long).to(self.device)
        total_logits = torch.empty((0, args.num_labels)).to(self.device)
        # The "features" collected below are simply the logits; they are returned
        # as-is when get_feats is True.
        total_features = torch.empty((0, args.num_labels)).to(self.device)
for batch in tqdm(dataloader, desc="Iteration"):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
with torch.set_grad_enabled(False):
logits, probs = self.model(input_ids, segment_ids, input_mask)
total_labels = torch.cat((total_labels,label_ids))
total_logits = torch.cat((total_logits, logits))
total_features = torch.cat((total_features, logits))
if get_feats:
feats = total_features.cpu().numpy()
return feats
else:
total_probs = F.softmax(total_logits.detach(), dim=1)
total_maxprobs, total_preds = total_probs.max(dim = 1)
y_pred = total_preds.cpu().numpy()
y_true = total_labels.cpu().numpy()
return y_true, y_pred