# THU-IAR's picture
# Upload 198 files
# 2d06dcc verified
from importlib import import_module
import torch
import torch.nn.functional as F
import numpy as np
import copy
import logging
from losses import loss_map
from torch import nn
from datetime import datetime
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm import trange, tqdm
from utils.functions import restore_model, save_model
from utils.metrics import F_measure
from .openmax_utils import recalibrate_scores, weibull_tailfitting, compute_distance
class OpenMaxManager:
    """Train/evaluate a BERT intent classifier and detect unseen intents with OpenMax.

    Training is supervised fine-tuning with early stopping on eval accuracy.
    At test time a Weibull model is fitted from the per-class mean logit
    vectors of the training split (``get_outputs(..., compute_centroids=True)``).
    Each test sample is then re-scored with ``recalibrate_scores``. Samples
    whose best recalibrated score is below ``args.threshold`` are labelled
    ``data.unseen_label_id``.
    """

    def __init__(self, args, data, model, logger_name = 'Detection'):
        """Set up model, optimizer, data loaders and loss function.

        Args:
            args: run configuration (reads ``loss_fct``, ``train``,
                ``model_output_dir`` and the optimizer settings consumed by
                ``set_model_optimizer``).
            data: dataset wrapper exposing ``data.dataloader`` loaders.
            model: object providing ``set_model``, ``set_optimizer`` and ``device``.
            logger_name: name passed to ``logging.getLogger``.
        """
        self.logger = logging.getLogger(logger_name)
        self.set_model_optimizer(args, data, model)
        self.data = data
        self.train_dataloader = data.dataloader.train_labeled_loader
        self.eval_dataloader = data.dataloader.eval_loader
        self.test_dataloader = data.dataloader.test_loader
        self.loss_fct = loss_map[args.loss_fct]

        # Initialise these unconditionally. get_outputs() reads
        # self.weibull_model, and it used to exist only when args.train was
        # true, so a restored manager crashed on an eval pass made before test().
        self.weibull_model = None
        self.train_results = []
        if not args.train:
            restore_model(self.model, args.model_output_dir)

    def set_model_optimizer(self, args, data, model):
        """Build the 'bert' model, its optimizer and LR scheduler, and record the device."""
        self.model = model.set_model(args, 'bert')
        self.optimizer, self.scheduler = model.set_optimizer(
            self.model,
            data.dataloader.num_train_examples,
            args.train_batch_size,
            args.num_train_epochs,
            args.lr,
            args.warmup_proportion,
        )
        self.device = model.device

    def train(self, args, data):
        """Fine-tune ``self.model`` with early stopping on eval accuracy.

        After every epoch the eval split is scored via ``get_outputs``. The
        accuracy (percent, 2 decimals) is logged. The best model seen so far
        is deep-copied, and training stops once ``args.wait_patient``
        non-improving epochs have accumulated.

        On exit:
          * ``self.model`` is the best model, or the current one if no epoch
            improved on a score of 0.
          * ``self.train_results`` holds the per-epoch metrics.
          * The model is saved when ``args.save_model`` is set.
        """
        self.logger.info('Training Start...')
        best_model = None
        wait = 0
        best_eval_score = 0
        train_results = []

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self.model.train()
            tr_loss = 0
            nb_tr_steps = 0

            for batch in tqdm(self.train_dataloader, desc="Iteration"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                with torch.set_grad_enabled(True):
                    loss = self.model(input_ids, segment_ids, input_mask, label_ids, mode='train', loss_fct=self.loss_fct)
                    loss.backward()
                    tr_loss += loss.item()
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    nb_tr_steps += 1

            # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
            loss = tr_loss / max(nb_tr_steps, 1)
            y_true, y_pred = self.get_outputs(args, data, mode = 'eval')
            eval_score = round(accuracy_score(y_true, y_pred) * 100, 2)
            eval_results = {
                'train_loss': loss,
                'eval_score': eval_score,
                # Best score *before* this epoch's update.
                'best_eval_score': best_eval_score,
            }
            train_results.append(eval_results)
            self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1))
            for key in sorted(eval_results.keys()):
                self.logger.info(" %s = %s", key, str(eval_results[key]))

            if eval_score > best_eval_score:
                best_model = copy.deepcopy(self.model)
                wait = 0
                best_eval_score = eval_score
            elif eval_score > 0:
                # NOTE: epochs scoring exactly 0 do not consume patience
                # (behaviour kept from the original logic).
                wait += 1
                if wait >= args.wait_patient:
                    break

        if best_model is not None:
            self.model = best_model
        else:
            # Previously self.model became None here, breaking every later call.
            self.logger.warning('No epoch improved the eval score; keeping the current model.')
        self.train_results = train_results

        if args.save_model:
            save_model(self.model, args.model_output_dir)

        self.logger.info('Training finished...')

    def get_outputs(self, args, data, mode = 'eval', get_feats = False, compute_centroids=False):
        """Run the model over one split and post-process its outputs.

        Args:
            args: run configuration (reads ``feat_dim`` and, when fitting,
                ``weibull_tail_size``).
            data: dataset wrapper (reads ``num_labels``).
            mode: which loader to use: ``'eval'``, ``'test'`` or ``'train'``.
            get_feats: if true, return the pooled features as a numpy array.
            compute_centroids: if true (and ``get_feats`` is false), fit and
                return a Weibull model from the per-class mean logits.

        Returns:
            ``feats`` if ``get_feats``. Otherwise the Weibull model if
            ``compute_centroids``. Otherwise ``(y_true, y_pred)``, where
            ``y_pred`` is the softmax argmax, or the OpenMax prediction once
            ``self.weibull_model`` has been fitted.

        Raises:
            ValueError: if ``mode`` is not one of the three splits.
        """
        if mode == 'eval':
            dataloader = self.eval_dataloader
        elif mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'train':
            dataloader = self.train_dataloader
        else:
            raise ValueError(f"mode must be 'eval', 'test' or 'train', got {mode!r}")

        self.model.eval()

        # Collect per-batch tensors and concatenate once at the end. Calling
        # torch.cat inside the loop re-copies everything seen so far. The empty
        # seed tensors keep the previous dtypes/shapes when the loader is empty.
        labels_parts = [torch.empty(0, dtype=torch.long).to(self.device)]
        logits_parts = [torch.empty((0, data.num_labels)).to(self.device)]
        feats_parts = [torch.empty((0, args.feat_dim)).to(self.device)]

        for batch in tqdm(dataloader, desc="Iteration"):
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.set_grad_enabled(False):
                pooled_output, logits = self.model(input_ids, segment_ids, input_mask)
            labels_parts.append(label_ids)
            logits_parts.append(logits)
            if get_feats:
                feats_parts.append(pooled_output)

        if get_feats:
            return torch.cat(feats_parts).cpu().numpy()

        total_labels = torch.cat(labels_parts)
        total_logits = torch.cat(logits_parts)

        total_probs = F.softmax(total_logits.detach(), dim=1)
        _, total_preds = total_probs.max(dim = 1)
        y_probs = total_probs.cpu().numpy()
        y_pred = total_preds.cpu().numpy()
        y_true = total_labels.cpu().numpy()
        y_logit = total_logits.cpu().numpy()

        if compute_centroids:
            centroids = self._compute_centroids(total_logits, total_labels, data.num_labels)
            mean_vecs, dis_sorted = self.cal_vec_dis(args, data, centroids, y_logit, y_true)
            # NOTE(review): a class with no samples yields an empty distance
            # list here; confirm weibull_tailfitting tolerates that.
            return weibull_tailfitting(mean_vecs, dis_sorted, data.num_labels, tailsize = args.weibull_tail_size)

        if self.weibull_model is not None:
            y_pred = self.classify_openmax(args, data, len(y_true), y_probs, y_logit)
        return y_true, y_pred

    @staticmethod
    def _compute_centroids(logits, labels, num_labels):
        """Return the per-class mean of ``logits`` rows as a numpy array of shape (num_labels, logits.size(1)).

        Classes with no samples get a zero vector. The earlier
        ``class_count``-based division dropped absent classes, which made the
        divisor shorter than the centroid matrix.
        """
        sums = torch.zeros(num_labels, logits.size(1), device=logits.device)
        # Scatter-add each row of logits into its class row in one vectorised call.
        sums.index_add_(0, labels, logits.float())
        counts = torch.bincount(labels, minlength=num_labels).clamp(min=1).float()
        return (sums / counts.unsqueeze(1)).detach().cpu().numpy()

    def test(self, args, data, show = False):
        """Fit the Weibull model on the training split and evaluate on the test split.

        Args:
            show: unused; kept for interface compatibility.

        Returns:
            A dict with the metrics from ``F_measure(cm)``, plus ``'Acc'``
            (accuracy in percent) and the raw ``'y_true'`` / ``'y_pred'``.
        """
        self.weibull_model = self.get_outputs(args, data, mode = 'train', compute_centroids=True)
        y_true, y_pred = self.get_outputs(args, data, mode = 'test')
        cm = confusion_matrix(y_true, y_pred)
        test_results = F_measure(cm)
        acc = round(accuracy_score(y_true, y_pred) * 100, 2)
        test_results['Acc'] = acc

        self.logger.info("***** Test: Confusion Matrix *****")
        self.logger.info("%s", str(cm))
        self.logger.info("***** Test results *****")
        for key in sorted(test_results.keys()):
            self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results

    def classify_openmax(self, args, data, num_samples, y_prob, y_logit):
        """Predict labels with OpenMax recalibration and a rejection threshold.

        Each of the first ``num_samples`` rows works as follows:
          * The probability row and the logit row are recalibrated by
            ``recalibrate_scores`` against ``self.weibull_model``.
          * The argmax of the recalibrated scores is the prediction.
          * If that score is below ``args.threshold``,
            ``data.unseen_label_id`` is predicted instead.

        Returns:
            A list of predicted label ids, one per sample.
        """
        # Loop-invariant: the rank is capped at the number of known classes.
        alpharank = min(args.alpharank, data.num_labels)
        y_preds = []
        for scores, fc8 in zip(y_prob[:num_samples], y_logit[:num_samples]):
            # recalibrate_scores takes a dict with 'scores' and 'fc8' entries.
            textarr = {'scores': scores, 'fc8': fc8}
            openmax, _softmax = recalibrate_scores(self.weibull_model, data.num_labels, textarr,
                                                   alpharank=alpharank)
            openmax = np.asarray(openmax)
            pred = np.argmax(openmax)
            if openmax[pred] < args.threshold:
                pred = data.unseen_label_id
            y_preds.append(pred)
        return y_preds

    def cal_vec_dis(self, args, data, centroids, y_logit, y_true):
        """Return the class mean vectors and the sorted distances of each class's logits to its mean.

        Returns:
            A pair ``(mean_vectors, dis_sorted)``:
              * ``mean_vectors``: a list with one centroid row per class.
              * ``dis_sorted``: for each class id ``0 .. data.num_labels - 1``,
                an ascending list of distances from that class's logit rows
                to its centroid.
        """
        mean_vectors = list(centroids)
        dis_sorted = [
            sorted(self.get_distances(args, y_logit[y_true == label], mean_vectors[label]))
            for label in range(data.num_labels)
        ]
        return mean_vectors, dis_sorted

    def get_distances(self, args, arr, mav):
        """Return ``compute_distance(row, mav, args.distance_type)`` for every row of ``arr``."""
        return [compute_distance(row, mav, args.distance_type) for row in arr]

    def class_count(self, labels):
        """Return the number of occurrences of each distinct label, in ascending label order.

        Labels that never occur are omitted, so the result is not indexed by
        class id when a class is missing. ``_compute_centroids`` uses a
        ``minlength``-padded count for that reason.
        """
        _, counts = np.unique(labels, return_counts=True)
        return counts.tolist()