# THU-IAR's picture
# Upload 198 files
# 2d06dcc verified
import torch
import torch.nn.functional as F
import numpy as np
import os
import copy
import logging
import pandas as pd
from torch import nn
from datetime import datetime
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm import trange, tqdm
from utils.functions import save_model
from utils.metrics import F_measure
from utils.functions import restore_model
from losses import loss_map
from sklearn.neighbors import LocalOutlierFactor
class SEGManager:
    """Train / evaluate / test driver for the SEG detection model.

    Holds the model built by ``model.set_model(args, 'bert')``, its optimizer and
    scheduler, and the labeled-train, eval and test dataloaders from ``data``.
    At test time, argmax predictions are post-processed with a Local Outlier
    Factor detector fit on training-set features; samples flagged as outliers
    are relabeled ``data.unseen_label_id``.
    """

    def __init__(self, args, data, model, logger_name = 'Detection'):
        self.logger = logging.getLogger(logger_name)
        self.set_model_optimizer(args, data, model)
        self.data = data
        self.train_dataloader = data.dataloader.train_labeled_loader
        self.eval_dataloader = data.dataloader.eval_loader
        self.test_dataloader = data.dataloader.test_loader
        # get_outputs() always passes self.p_y to the model, so the prior must
        # exist even when train() is never called (args.train False).
        self.p_y = self._compute_class_prior(data)
        if args.train:
            self.best_features = None
        else:
            restore_model(self.model, args.model_output_dir)
            self.best_features = np.load(os.path.join(args.method_output_dir, 'features.npy'))

    @staticmethod
    def _compute_class_prior(data):
        """Return the empirical prior of each known class as a float64 tensor.

        Entries follow ``data.known_label_list`` order, which is the same order
        get_class_feats() uses for the class embeddings. (``np.unique`` alone
        would return counts in sorted label order, which need not match.)
        Classes absent from the labeled training set get probability 0.
        """
        train_labels = [example.label for example in data.dataloader.train_labeled_examples]
        unique_labels, unique_counts = np.unique(train_labels, return_counts=True)
        count_by_label = dict(zip(unique_labels.tolist(), unique_counts.tolist()))
        counts = np.array([count_by_label.get(label, 0) for label in data.known_label_list])
        return torch.tensor(counts / data.dataloader.num_train_examples)

    def set_model_optimizer(self, args, data, model):
        """Build the BERT model and its optimizer/scheduler; record the device."""
        self.model = model.set_model(args, 'bert')
        self.optimizer, self.scheduler = model.set_optimizer(
            self.model, data.dataloader.num_train_examples, args.train_batch_size,
            args.num_train_epochs, args.lr, args.warmup_proportion)
        self.device = model.device

    def get_class_feats(self, args, data):
        """Tokenize the known label names into BERT inputs.

        Returns:
            tuple ``(input_ids, segment_ids, input_mask)`` of long tensors, one row
            per entry of ``data.known_label_list`` in that order.
        """
        from dataloaders.bert_loader import convert_examples_to_features, InputExample
        from transformers import BertTokenizer

        # The code assumes StackOverflow label names separate words with '-'
        # and the other datasets use '_'.
        separator = '-' if args.dataset == 'stackoverflow' else '_'
        label_texts = [label.replace(separator, ' ') for label in data.known_label_list]
        examples = [
            InputExample(guid="label-%s" % i, text_a=text, text_b=None, label=None)
            for i, text in enumerate(label_texts)
        ]

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        # Size by WordPiece token count (not whitespace word count, which can be
        # smaller) so no label name gets truncated; +2 presumably leaves room for
        # the special tokens added by convert_examples_to_features.
        max_label_length = max(len(tokenizer.tokenize(text)) for text in label_texts) + 2

        features = convert_examples_to_features(examples, None, max_label_length, tokenizer)
        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        return (input_ids, segment_ids, input_mask)

    def train(self, args, data):
        """Fine-tune on the labeled training set with eval-accuracy early stopping.

        After each epoch the eval accuracy (in percent, 2 decimals) is logged. A
        deep copy of the model is kept whenever the score ties or beats the best
        so far. Training stops after ``args.wait_patient`` consecutive epochs
        without improvement. The best copy replaces ``self.model`` and is saved
        when ``args.save_model`` is set.
        """
        self.p_y = self._compute_class_prior(data)
        self.logger.info("Priori probability of each class = %s", self.p_y.numpy())

        if args.class_emb:
            self.class_feats = tuple(t.to(self.device) for t in self.get_class_feats(args, data))
            class_ids, class_segment, class_mask = self.class_feats

        best_model = None
        best_eval_score = 0
        wait = 0

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self.model.train()
            tr_loss = 0
            nb_tr_steps = 0

            for batch in tqdm(self.train_dataloader, desc="Iteration"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                # Label-name embeddings are recomputed every step because the
                # encoder keeps changing, but no gradient flows through them.
                with torch.set_grad_enabled(False):
                    class_emb = self.model(class_ids, class_segment, class_mask, feature_ext=True) if args.class_emb else None
                with torch.set_grad_enabled(True):
                    loss = self.model(input_ids, segment_ids, input_mask, label_ids, mode='train', device=self.device, class_emb=class_emb, p_y = self.p_y)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()
                    tr_loss += loss.item()
                    nb_tr_steps += 1

            # max(..., 1) guards against an empty training dataloader.
            train_loss = tr_loss / max(nb_tr_steps, 1)
            y_true, y_pred = self.get_outputs(args, data, self.eval_dataloader)
            eval_score = round(accuracy_score(y_true, y_pred) * 100, 2)
            eval_results = {
                'train_loss': train_loss,
                'eval_acc': eval_score,
                'best_acc': best_eval_score,
            }
            self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1))
            for key in sorted(eval_results.keys()):
                self.logger.info(" %s = %s", key, str(eval_results[key]))

            if eval_score >= best_eval_score:
                best_eval_score = eval_score
                best_model = copy.deepcopy(self.model)
                wait = 0
            else:
                wait += 1
                if wait >= args.wait_patient:
                    break

        # If no epoch ran, keep the current model rather than replacing it with None.
        if best_model is not None:
            self.model = best_model
        if args.save_model:
            save_model(self.model, args.model_output_dir)

    def classify_lof(self, data, preds, train_feats, pred_feats, n_neighbors=20, contamination=0.05):
        """Relabel LOF outliers as the unseen class.

        Fits a novelty-mode ``LocalOutlierFactor`` on ``train_feats``. Every
        position where ``pred_feats`` is predicted as an outlier (-1) is set to
        ``data.unseen_label_id`` in ``preds``. ``preds`` is modified in place and
        also returned.
        """
        lof = LocalOutlierFactor(n_neighbors=n_neighbors, contamination=contamination, novelty=True, n_jobs=-1)
        lof.fit(train_feats)
        outlier_mask = lof.predict(pred_feats) == -1
        preds[outlier_mask] = data.unseen_label_id
        return preds

    def get_outputs(self, args, data, dataloader, get_feats = False, train_feats = None):
        """Run the model over ``dataloader`` in eval mode.

        Returns:
            If ``get_feats``, a numpy array of the pooled features. Otherwise
            ``(y_true, y_pred)`` numpy arrays, where ``y_pred`` is the logit
            argmax, further post-processed by classify_lof() when ``train_feats``
            is given.
        """
        self.model.eval()
        # Each list is seeded with an empty tensor, so an empty dataloader still
        # yields correctly shaped results. Everything is concatenated once at the
        # end instead of re-copying the accumulator on every batch.
        label_chunks = [torch.empty(0, dtype=torch.long, device=self.device)]
        logit_chunks = [torch.empty((0, data.num_labels), device=self.device)]
        feature_chunks = [torch.empty((0, args.feat_dim), device=self.device)]

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Iteration"):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                pooled_output, logits = self.model(input_ids, segment_ids, input_mask, p_y = self.p_y, device = self.device)
                label_chunks.append(label_ids)
                logit_chunks.append(logits)
                feature_chunks.append(pooled_output)

        total_features = torch.cat(feature_chunks)
        if get_feats:
            return total_features.cpu().numpy()

        total_labels = torch.cat(label_chunks)
        total_logits = torch.cat(logit_chunks)
        y_pred = torch.argmax(total_logits, dim=1).cpu().numpy()
        y_true = total_labels.cpu().numpy()
        if train_feats is not None:
            y_pred = self.classify_lof(data, y_pred, train_feats, total_features.cpu().numpy())
        return y_true, y_pred

    def test(self, args, data, show=False):
        """Evaluate on the test set with LOF-based unseen-class detection.

        Returns a dict containing the F_measure() results computed from the
        confusion matrix, plus ``'Acc'`` (percent, 2 decimals), ``'y_true'`` and
        ``'y_pred'``. When ``show`` is true, the matrix and results are logged.
        """
        train_feats = self.get_outputs(args, data, self.train_dataloader, get_feats = True)
        y_true, y_pred = self.get_outputs(args, data, self.test_dataloader, train_feats = train_feats)

        cm = confusion_matrix(y_true, y_pred)
        test_results = F_measure(cm)
        test_results['Acc'] = round(accuracy_score(y_true, y_pred) * 100, 2)

        if show:
            self.logger.info("***** Test: Confusion Matrix *****")
            self.logger.info("%s", str(cm))
            self.logger.info("***** Test results *****")
            for key in sorted(test_results.keys()):
                self.logger.info(" %s = %s", key, str(test_results[key]))

        test_results['y_true'] = y_true
        test_results['y_pred'] = y_pred
        return test_results