MukeshKapoor25's picture
changs
6f2ff70
import os
import gc
import torch
import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import LogDataset, WordVocab
from bert_pytorch.dataset.sample import generate_train_valid
from bert_pytorch.dataset.utils import save_parameters
class Trainer():
def __init__(self, options):
self.device = options["device"]
self.model_dir = options["model_dir"]
self.model_path = options["model_path"]
self.vocab_path = options["vocab_path"]
self.output_path = options["output_dir"]
self.window_size = options["window_size"]
self.adaptive_window = options["adaptive_window"]
self.sample_ratio = options["train_ratio"]
self.valid_ratio = options["valid_ratio"]
self.seq_len = options["seq_len"]
self.max_len = options["max_len"]
self.corpus_lines = options["corpus_lines"]
self.on_memory = options["on_memory"]
self.batch_size = options["batch_size"]
self.num_workers = options["num_workers"]
self.lr = options["lr"]
self.adam_beta1 = options["adam_beta1"]
self.adam_beta2 = options["adam_beta2"]
self.adam_weight_decay = options["adam_weight_decay"]
self.with_cuda = options["with_cuda"]
self.cuda_devices = options["cuda_devices"]
self.log_freq = options["log_freq"]
self.epochs = options["epochs"]
self.hidden = options["hidden"]
self.layers = options["layers"]
self.attn_heads = options["attn_heads"]
self.is_logkey = options["is_logkey"]
self.is_time = options["is_time"]
self.scale = options["scale"]
self.scale_path = options["scale_path"]
self.n_epochs_stop = options["n_epochs_stop"]
self.hypersphere_loss = options["hypersphere_loss"]
self.mask_ratio = options["mask_ratio"]
self.min_len = options["min_len"]
print("Save options parameters")
save_parameters(options, self.model_dir + "parameters.txt")
def train(self):
print("Loading vocab", self.vocab_path)
vocab = WordVocab.load_vocab(self.vocab_path)
print("vocab Size: ", len(vocab))
print("\nLoading Train Dataset")
train_file_path = os.path.join(self.output_path, "train")
logkey_train, logkey_valid, time_train, time_valid = generate_train_valid(
train_file_path,
window_size=self.window_size,
adaptive_window=self.adaptive_window,
valid_size=self.valid_ratio,
sample_ratio=self.sample_ratio,
scale=self.scale,
scale_path=self.scale_path,
seq_len=self.seq_len,
min_len=self.min_len
)
train_dataset = LogDataset(
logkey_train, time_train, vocab,
seq_len=self.seq_len,
corpus_lines=self.corpus_lines,
on_memory=self.on_memory,
mask_ratio=self.mask_ratio
)
print("\nLoading valid Dataset")
valid_dataset = LogDataset(
logkey_valid, time_valid, vocab,
seq_len=self.seq_len,
on_memory=self.on_memory,
mask_ratio=self.mask_ratio
)
print("Creating Dataloader")
self.train_data_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
collate_fn=train_dataset.collate_fn,
drop_last=False
)
self.valid_data_loader = DataLoader(
valid_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
collate_fn=train_dataset.collate_fn,
drop_last=False
)
del train_dataset
del valid_dataset
del logkey_train
del logkey_valid
del time_train
del time_valid
gc.collect()
print("Building BERT model")
bert = BERT(
len(vocab),
max_len=self.max_len,
hidden=self.hidden,
n_layers=self.layers,
attn_heads=self.attn_heads,
is_logkey=self.is_logkey,
is_time=self.is_time
)
print("Creating BERT Trainer")
self.trainer = BERTTrainer(
bert, len(vocab),
train_dataloader=self.train_data_loader,
valid_dataloader=self.valid_data_loader,
lr=self.lr,
betas=(self.adam_beta1, self.adam_beta2),
weight_decay=self.adam_weight_decay,
with_cuda=self.with_cuda,
cuda_devices=self.cuda_devices,
log_freq=self.log_freq,
is_logkey=self.is_logkey,
is_time=self.is_time,
hypersphere_loss=self.hypersphere_loss
)
self.start_iteration(surfix_log="log2")
self.plot_train_valid_loss("_log2")
def start_iteration(self, surfix_log):
print("Training Start")
best_loss = float('inf')
epochs_no_improve = 0
for epoch in range(self.epochs):
print("\n")
if self.hypersphere_loss:
center = self.calculate_center([self.train_data_loader, self.valid_data_loader])
self.trainer.hyper_center = center
_, train_dist = self.trainer.train(epoch)
avg_loss, valid_dist = self.trainer.valid(epoch)
self.trainer.save_log(self.model_dir, surfix_log)
if self.hypersphere_loss:
self.trainer.radius = self.trainer.get_radius(train_dist + valid_dist, self.trainer.nu)
if avg_loss < best_loss:
best_loss = avg_loss
self.trainer.save(self.model_path)
epochs_no_improve = 0
if epoch > 10 and self.hypersphere_loss:
best_center = self.trainer.hyper_center
best_radius = self.trainer.radius
total_dist = train_dist + valid_dist
if best_center is None:
raise TypeError("center is None")
print("best radius", best_radius)
best_center_path = self.model_dir + "best_center.pt"
print("Save best center", best_center_path)
torch.save({"center": best_center, "radius": best_radius}, best_center_path)
total_dist_path = self.model_dir + "best_total_dist.pt"
print("save total dist: ", total_dist_path)
torch.save(total_dist, total_dist_path)
else:
epochs_no_improve += 1
if epochs_no_improve == self.n_epochs_stop:
print("Early stopping")
break
def calculate_center(self, data_loader_list):
print("start calculate center")
with torch.no_grad():
outputs = 0
total_samples = 0
for data_loader in data_loader_list:
totol_length = len(data_loader)
data_iter = tqdm.tqdm(enumerate(data_loader), total=totol_length)
for i, data in data_iter:
data = {key: value.to(self.device) for key, value in data.items()}
result = self.trainer.model.forward(data["bert_input"], data["time_input"])
cls_output = result["cls_output"]
outputs += torch.sum(cls_output.detach().clone(), dim=0)
total_samples += cls_output.size(0)
center = outputs / total_samples
return center
def plot_train_valid_loss(self, surfix_log):
train_loss = pd.read_csv(self.model_dir + f"train{surfix_log}.csv")
valid_loss = pd.read_csv(self.model_dir + f"valid{surfix_log}.csv")
sns.lineplot(x="epoch", y="loss", data=train_loss, label="train loss")
sns.lineplot(x="epoch", y="loss", data=valid_loss, label="valid loss")
plt.title("epoch vs train loss vs valid loss")
plt.legend()
plt.savefig(self.model_dir + "train_valid_loss.png")
plt.show()
print("plot done")