import os
import gc

import torch
import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import LogDataset, WordVocab
from bert_pytorch.dataset.sample import generate_train_valid
from bert_pytorch.dataset.utils import save_parameters


class Trainer():
    """Orchestrate BERT-based log-sequence training.

    Responsibilities, in order:
      1. Build train/valid ``LogDataset`` splits from a pre-generated
         "train" file and wrap them in ``DataLoader``s.
      2. Train a masked BERT model via ``BERTTrainer``, optionally with a
         hypersphere (one-class / SVDD-style) loss whose center is
         re-estimated each epoch.
      3. Early-stop on validation loss, checkpoint the best model, and
         plot the train/valid loss curves.

    NOTE(review): ``model_dir`` is concatenated with filenames using ``+``
    throughout, so it is assumed to end with a path separator — confirm
    against the caller that builds ``options``.
    """

    def __init__(self, options):
        """Copy run configuration out of ``options`` and persist it.

        Args:
            options: dict of run parameters. Every key read below is
                required; a missing key raises ``KeyError`` immediately,
                which surfaces configuration mistakes early.

        Side effects:
            Writes all parameters to ``<model_dir>parameters.txt``.
        """
        # Device / paths.
        self.device = options["device"]
        self.model_dir = options["model_dir"]
        self.model_path = options["model_path"]
        self.vocab_path = options["vocab_path"]
        self.output_path = options["output_dir"]

        # Sequence windowing / sampling.
        self.window_size = options["window_size"]
        self.adaptive_window = options["adaptive_window"]
        self.sample_ratio = options["train_ratio"]
        self.valid_ratio = options["valid_ratio"]
        self.seq_len = options["seq_len"]
        self.max_len = options["max_len"]
        self.corpus_lines = options["corpus_lines"]
        self.on_memory = options["on_memory"]

        # DataLoader settings.
        self.batch_size = options["batch_size"]
        self.num_workers = options["num_workers"]

        # Optimizer (Adam) hyper-parameters.
        self.lr = options["lr"]
        self.adam_beta1 = options["adam_beta1"]
        self.adam_beta2 = options["adam_beta2"]
        self.adam_weight_decay = options["adam_weight_decay"]

        # CUDA / logging.
        self.with_cuda = options["with_cuda"]
        self.cuda_devices = options["cuda_devices"]
        self.log_freq = options["log_freq"]

        # Model architecture.
        self.epochs = options["epochs"]
        self.hidden = options["hidden"]
        self.layers = options["layers"]
        self.attn_heads = options["attn_heads"]

        # Task flags: predict log keys and/or time deltas.
        self.is_logkey = options["is_logkey"]
        self.is_time = options["is_time"]

        # Time-value scaling (e.g. a fitted scaler persisted at scale_path).
        self.scale = options["scale"]
        self.scale_path = options["scale_path"]

        # Early stopping / hypersphere loss / masking.
        self.n_epochs_stop = options["n_epochs_stop"]
        self.hypersphere_loss = options["hypersphere_loss"]
        self.mask_ratio = options["mask_ratio"]
        self.min_len = options["min_len"]

        print("Save options parameters")
        save_parameters(options, self.model_dir + "parameters.txt")

    def train(self):
        """Build datasets/loaders, construct the model, and run training.

        Reads the vocab from ``self.vocab_path`` and the training corpus
        from ``<output_path>/train``, then delegates the epoch loop to
        :meth:`start_iteration` and finally plots the loss curves.
        """
        print("Loading vocab", self.vocab_path)
        vocab = WordVocab.load_vocab(self.vocab_path)
        print("vocab Size: ", len(vocab))

        print("\nLoading Train Dataset")
        train_file_path = os.path.join(self.output_path, "train")
        logkey_train, logkey_valid, time_train, time_valid = generate_train_valid(
            train_file_path,
            window_size=self.window_size,
            adaptive_window=self.adaptive_window,
            valid_size=self.valid_ratio,
            sample_ratio=self.sample_ratio,
            scale=self.scale,
            scale_path=self.scale_path,
            seq_len=self.seq_len,
            min_len=self.min_len
        )

        train_dataset = LogDataset(
            logkey_train, time_train, vocab,
            seq_len=self.seq_len,
            corpus_lines=self.corpus_lines,
            on_memory=self.on_memory,
            mask_ratio=self.mask_ratio
        )

        print("\nLoading valid Dataset")
        valid_dataset = LogDataset(
            logkey_valid, time_valid, vocab,
            seq_len=self.seq_len,
            on_memory=self.on_memory,
            mask_ratio=self.mask_ratio
        )

        print("Creating Dataloader")
        self.train_data_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False
        )
        # FIX: use the valid dataset's own collate_fn (the original bound
        # train_dataset.collate_fn here, which also kept the deleted
        # train_dataset alive through the bound-method reference).
        self.valid_data_loader = DataLoader(
            valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=valid_dataset.collate_fn,
            drop_last=False
        )

        # The loaders hold what they need; drop the large intermediates
        # before building the model to reduce peak memory.
        del train_dataset
        del valid_dataset
        del logkey_train
        del logkey_valid
        del time_train
        del time_valid
        gc.collect()

        print("Building BERT model")
        bert = BERT(
            len(vocab),
            max_len=self.max_len,
            hidden=self.hidden,
            n_layers=self.layers,
            attn_heads=self.attn_heads,
            is_logkey=self.is_logkey,
            is_time=self.is_time
        )

        print("Creating BERT Trainer")
        self.trainer = BERTTrainer(
            bert, len(vocab),
            train_dataloader=self.train_data_loader,
            valid_dataloader=self.valid_data_loader,
            lr=self.lr,
            betas=(self.adam_beta1, self.adam_beta2),
            weight_decay=self.adam_weight_decay,
            with_cuda=self.with_cuda,
            cuda_devices=self.cuda_devices,
            log_freq=self.log_freq,
            is_logkey=self.is_logkey,
            is_time=self.is_time,
            hypersphere_loss=self.hypersphere_loss
        )

        self.start_iteration(surfix_log="log2")
        self.plot_train_valid_loss("_log2")

    def start_iteration(self, surfix_log):
        """Run the epoch loop with early stopping on validation loss.

        Args:
            surfix_log: suffix used by the trainer when writing the
                per-epoch train/valid loss CSVs.

        Behavior:
            * When ``hypersphere_loss`` is on, the hypersphere center is
              recomputed from both loaders before every epoch, and the
              radius is refit from the epoch's train+valid distances.
            * The model is checkpointed whenever validation loss improves;
              after epoch 10 the best center/radius/distances are also
              saved (hypersphere mode only).
            * Training stops after ``n_epochs_stop`` consecutive epochs
              without improvement.
        """
        print("Training Start")
        best_loss = float('inf')
        epochs_no_improve = 0

        for epoch in range(self.epochs):
            print("\n")
            if self.hypersphere_loss:
                center = self.calculate_center([self.train_data_loader, self.valid_data_loader])
                self.trainer.hyper_center = center

            _, train_dist = self.trainer.train(epoch)
            avg_loss, valid_dist = self.trainer.valid(epoch)
            self.trainer.save_log(self.model_dir, surfix_log)

            if self.hypersphere_loss:
                # Refit the radius from all distances seen this epoch
                # (nu controls the outlier quantile).
                self.trainer.radius = self.trainer.get_radius(train_dist + valid_dist, self.trainer.nu)

            if avg_loss < best_loss:
                best_loss = avg_loss
                self.trainer.save(self.model_path)
                epochs_no_improve = 0

                # Only persist hypersphere state once training has settled
                # past the first 10 epochs.
                if epoch > 10 and self.hypersphere_loss:
                    best_center = self.trainer.hyper_center
                    best_radius = self.trainer.radius
                    total_dist = train_dist + valid_dist

                    if best_center is None:
                        raise TypeError("center is None")

                    print("best radius", best_radius)
                    best_center_path = self.model_dir + "best_center.pt"
                    print("Save best center", best_center_path)
                    torch.save({"center": best_center, "radius": best_radius}, best_center_path)

                    total_dist_path = self.model_dir + "best_total_dist.pt"
                    print("save total dist: ", total_dist_path)
                    torch.save(total_dist, total_dist_path)
            else:
                epochs_no_improve += 1
                if epochs_no_improve == self.n_epochs_stop:
                    print("Early stopping")
                    break

    def calculate_center(self, data_loader_list):
        """Compute the mean CLS embedding over one or more data loaders.

        Args:
            data_loader_list: list of ``DataLoader``s whose batches are
                dicts with at least ``bert_input`` and ``time_input``.

        Returns:
            The element-wise mean of the model's ``cls_output`` over all
            samples — used as the hypersphere center.
        """
        print("start calculate center")
        with torch.no_grad():
            outputs = 0            # running sum of cls_output rows
            total_samples = 0
            for data_loader in data_loader_list:
                total_length = len(data_loader)
                data_iter = tqdm.tqdm(enumerate(data_loader), total=total_length)
                for i, data in data_iter:
                    data = {key: value.to(self.device) for key, value in data.items()}
                    result = self.trainer.model.forward(data["bert_input"], data["time_input"])
                    cls_output = result["cls_output"]
                    # detach().clone() drops autograd history before
                    # accumulating.
                    outputs += torch.sum(cls_output.detach().clone(), dim=0)
                    total_samples += cls_output.size(0)

        # NOTE(review): raises ZeroDivisionError (or yields int 0/0) if the
        # loaders are empty — assumed non-empty by construction.
        center = outputs / total_samples
        return center

    def plot_train_valid_loss(self, surfix_log):
        """Plot per-epoch train vs valid loss and save the figure.

        Reads ``train<surfix_log>.csv`` / ``valid<surfix_log>.csv`` (as
        written by the trainer's ``save_log``) from ``model_dir`` and
        writes ``train_valid_loss.png`` next to them.
        """
        train_loss = pd.read_csv(self.model_dir + f"train{surfix_log}.csv")
        valid_loss = pd.read_csv(self.model_dir + f"valid{surfix_log}.csv")
        sns.lineplot(x="epoch", y="loss", data=train_loss, label="train loss")
        sns.lineplot(x="epoch", y="loss", data=valid_loss, label="valid loss")
        plt.title("epoch vs train loss vs valid loss")
        plt.legend()
        plt.savefig(self.model_dir + "train_valid_loss.png")
        plt.show()
        print("plot done")