|
|
import os
|
|
|
import torch
|
|
|
import platform
|
|
|
from time import time
|
|
|
from tqdm import tqdm
|
|
|
from trainer import train
|
|
|
import torch.optim as optim
|
|
|
from utils.parser import parse_args
|
|
|
from utils.logger import init_logger
|
|
|
from utils.configurator import Config
|
|
|
from torch.utils.data import DataLoader
|
|
|
from tensorboardX import SummaryWriter
|
|
|
from utils.evaluator import evaluate_model
|
|
|
from utils.data_loader import Load_dataset, Load_eval_dataset
|
|
|
from utils.helper import early_stopping, plot_curve, res_output, stop_log, update_result, sele_para
|
|
|
from model import REARM
|
|
|
|
|
|
|
|
|
class Net:
    """Training/evaluation driver for the REARM model.

    Wires together config, logging, data loaders, model, optimizer and
    LR scheduler, then runs the train → validate → (test-on-improvement)
    loop with early stopping.
    """

    def __init__(self, args):
        """Build every training component from parsed CLI *args*.

        Args:
            args: namespace produced by ``utils.parser.parse_args``,
                consumed by ``Config``.
        """
        self.config = Config(args)

        # Logging: record the full config plus host/cwd for reproducibility.
        self.logger = init_logger(self.config)
        self.logger.info(self.config)
        self.logger.info('██Server: \t' + platform.node())
        self.logger.info('██Dir: \t' + os.getcwd() + '\n')

        # Frequently used config values cached as plain attributes.
        self.device = self.config.device
        self.model_name = self.config.model_name
        self.dataset_name = self.config.dataset
        self.batch_size = self.config.batch_size
        self.num_workers = self.config.num_workers
        self.learning_rate = self.config.learning_rate
        self.num_epoch = self.config.num_epoch
        self.topk = self.config.topk
        self.metrics = self.config.metrics
        self.valid_metric = self.config.valid_metric
        self.stopping_step = self.config.stopping_step
        self.reg_weight = self.config.reg_weight

        # Early-stopping state, mutated by run() each epoch.
        self.cur_step = 0
        self.best_valid_score = -1
        self.best_valid_result = {}
        self.best_test_upon_valid = {}

        # Optional TensorBoard writer (enabled via config).
        self.writer = SummaryWriter() if self.config.writer else None

        # Data: training DataLoader plus held-out validation/test loaders.
        Dataset = Load_dataset(self.config)
        valid_dataset, test_dataset = Dataset.load_eval_data()
        self.train_data = DataLoader(Dataset, batch_size=self.batch_size,
                                     shuffle=True, num_workers=self.num_workers)
        self.valid_data = Load_eval_dataset("Validation", self.config, valid_dataset)
        self.test_data = Load_eval_dataset("Testing", self.config, test_dataset)

        # Model, optimizer and exponentially decaying LR schedule.
        self.model = REARM(self.config, Dataset).to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), self.learning_rate,
                                     weight_decay=self.reg_weight)
        lr_scheduler = self.config.learning_rate_scheduler

        # LR multiplier = base ** (epoch / step); named function instead of
        # an assigned lambda (PEP 8 E731).
        def fac(epoch):
            return lr_scheduler[0] ** (epoch / lr_scheduler[1])

        self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
        self.logger.info(self.model)

    def plot_train_loss(self):
        """Plot the recorded training-loss curve (delegates to helper)."""
        plot_curve(self)

    def run(self):
        """Run the full training loop with early stopping.

        Returns:
            dict: best test-set result observed at the best validation
            score, or ``{"Recall@20": -1}`` if the loss diverged before
            any test evaluation happened.
        """
        run_start_time = time()
        for epoch_idx in tqdm(range(self.num_epoch)):
            train_start_time = time()
            train_loss = train(self, epoch_idx)

            # Abort on divergence; report the last good test result if any.
            if torch.isnan(train_loss[0]):
                ret_value = {"Recall@20": -1} if self.best_test_upon_valid == {} else self.best_test_upon_valid
                stop_output = '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
                self.logger.info(stop_output)
                # Fixed message: original ran the "{}" value and "Exiting."
                # together with no separator.
                self.logger.info('Loss is nan at epoch: {}; last value is {}. Exiting.'.format(
                    epoch_idx, ret_value))
                return ret_value

            self.lr_scheduler.step()

            train_output = res_output(epoch_idx, train_start_time, time(), train_loss, "train")
            self.logger.info(train_output)

            # Validation + early-stopping bookkeeping.
            valid_start_time = time()
            valid_score, valid_result = evaluate_model(self, epoch_idx, self.valid_data, t_or_v="valid")
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score, self.best_valid_score, self.cur_step, self.stopping_step)
            self.best_valid_result[epoch_idx] = self.best_valid_score
            valid_output = res_output(epoch_idx, valid_start_time, time(), valid_result, t_or_v="valid")
            self.logger.info(valid_output)

            # New best validation score: evaluate on the test split and
            # record it as the result to report.
            if update_flag:
                test_start_time = time()
                _, test_result = evaluate_model(self, epoch_idx, self.test_data, t_or_v="test")
                test_score_output = res_output(epoch_idx, test_start_time, time(), test_result, t_or_v="test")
                self.logger.info(test_score_output)
                update_result(self, test_result)

            if stop_flag:
                stop_log(self, epoch_idx, run_start_time)
                break
            else:
                # Log (was a bare print) so remaining patience is captured
                # in the log file alongside every other status line.
                self.logger.info('patience ==> %d' % (self.stopping_step - self.cur_step))
        return self.best_test_upon_valid
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse CLI arguments, construct the driver, and launch training.
    cli_args = parse_args()
    net = Net(cli_args)
    best_score = net.run()
|
|
|
|