File size: 5,156 Bytes
f60c555 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import torch
import platform
from time import time
from tqdm import tqdm
from trainer import train
import torch.optim as optim
from utils.parser import parse_args
from utils.logger import init_logger
from utils.configurator import Config
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from utils.evaluator import evaluate_model
from utils.data_loader import Load_dataset, Load_eval_dataset
from utils.helper import early_stopping, plot_curve, res_output, stop_log, update_result, sele_para
from model import REARM
class Net:
    """Training driver for the REARM model.

    Builds the configuration, logger, data loaders, model, optimizer and
    learning-rate scheduler, then runs the epoch loop with validation-based
    early stopping (see ``run``).
    """

    def __init__(self, args):
        """Set up every component needed for training.

        Args:
            args: parsed command-line arguments (the ``__main__`` block passes
                the result of ``parse_args()``); forwarded to ``Config``.
        """
        # Complete initialization of all parameters (including random seeds)
        self.config = Config(args)
        # Use logger
        self.logger = init_logger(self.config)
        self.logger.info(self.config)
        self.logger.info('██Server: \t' + platform.node())
        self.logger.info('██Dir: \t' + os.getcwd() + '\n')

        # Frequently used settings, copied off the config for brevity.
        self.device = self.config.device
        self.model_name = self.config.model_name
        self.dataset_name = self.config.dataset
        self.batch_size = self.config.batch_size
        self.num_workers = self.config.num_workers
        self.learning_rate = self.config.learning_rate
        self.num_epoch = self.config.num_epoch
        self.topk = self.config.topk
        self.metrics = self.config.metrics
        self.valid_metric = self.config.valid_metric
        self.stopping_step = self.config.stopping_step
        self.reg_weight = self.config.reg_weight

        # Early-stopping / result bookkeeping (updated in ``run``).
        self.cur_step = 0
        self.best_valid_score = -1
        # epoch index -> value of best_valid_score after that epoch's early-stopping update
        self.best_valid_result = {}
        self.best_test_upon_valid = {}

        # Writer will output to ./runs/ directory by default
        self.writer = SummaryWriter() if self.config.writer else None

        # Perform experimental configurations
        train_dataset = Load_dataset(self.config)
        valid_dataset, test_dataset = train_dataset.load_eval_data()
        self.train_data = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True,
                                     num_workers=self.num_workers)
        self.valid_data = Load_eval_dataset("Validation", self.config, valid_dataset)
        self.test_data = Load_eval_dataset("Testing", self.config, test_dataset)

        self.model = REARM(self.config, train_dataset).to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), self.learning_rate,
                                     weight_decay=self.reg_weight)

        # assumes learning_rate_scheduler is a 2-item sequence (decay_factor, decay_period)
        # — TODO confirm against the config files.
        lr_decay = self.config.learning_rate_scheduler

        def lr_multiplier(epoch):
            # Smooth exponential decay: the base LR is multiplied by lr_decay[0]
            # once every lr_decay[1] epochs (fractionally in between).
            return lr_decay[0] ** (epoch / lr_decay[1])

        self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_multiplier)
        self.logger.info(self.model)

    def plot_train_loss(self):
        """Delegate plotting to ``utils.helper.plot_curve``."""
        plot_curve(self)

    @staticmethod
    def _is_nan_loss(loss):
        """Return True if ``loss`` (a tensor or a plain Python number) contains a NaN."""
        # as_tensor accepts both tensors (returned unchanged) and Python floats;
        # .any() also tolerates a multi-element loss tensor.
        return bool(torch.isnan(torch.as_tensor(loss)).any())

    def _best_or_placeholder(self):
        """Return the best test result recorded so far, or a ``Recall@20 = -1`` placeholder."""
        return self.best_test_upon_valid if self.best_test_upon_valid else {"Recall@20": -1}

    def run(self):
        """Train for up to ``num_epoch`` epochs with validation-based early stopping.

        Each epoch trains, steps the LR scheduler, and evaluates on the
        validation split. It then feeds the score to ``early_stopping``. When
        that signals an update, it also evaluates on the test split and passes
        the result to ``update_result``.

        Returns:
            ``self.best_test_upon_valid``. It is presumably filled in by
            ``update_result`` (confirm in utils.helper). If the training loss
            becomes NaN, training aborts and the same value is returned early.
            If that value is still empty, ``{"Recall@20": -1}`` is returned.
        """
        run_start_time = time()
        for epoch_idx in tqdm(range(self.num_epoch)):
            # ---- train ----
            train_start_time = time()
            train_loss = train(self, epoch_idx)
            # Save if an exception occurs: a NaN loss aborts the run and
            # returns the best result recorded so far.
            if self._is_nan_loss(train_loss[0]):
                ret_value = self._best_or_placeholder()
                stop_output = '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
                self.logger.info(stop_output)
                self.logger.info('Loss is nan at epoch: {}; last value is {}. Exiting.'.format(epoch_idx, ret_value))
                return ret_value
            self.lr_scheduler.step()
            train_output = res_output(epoch_idx, train_start_time, time(), train_loss, "train")
            self.logger.info(train_output)

            # ---- validate and update early-stopping state ----
            valid_start_time = time()
            valid_score, valid_result = evaluate_model(self, epoch_idx, self.valid_data, t_or_v="valid")
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score, self.best_valid_score, self.cur_step, self.stopping_step)
            self.best_valid_result[epoch_idx] = self.best_valid_score
            valid_output = res_output(epoch_idx, valid_start_time, time(), valid_result, t_or_v="valid")
            self.logger.info(valid_output)

            # ---- test only when early_stopping reports an update ----
            if update_flag:
                test_start_time = time()
                _, test_result = evaluate_model(self, epoch_idx, self.test_data, t_or_v="test")
                test_score_output = res_output(epoch_idx, test_start_time, time(), test_result, t_or_v="test")
                self.logger.info(test_score_output)
                update_result(self, test_result)

            if stop_flag:
                stop_log(self, epoch_idx, run_start_time)
                break
            # tqdm.write keeps the progress bar intact, unlike a bare print().
            tqdm.write('patience ==> %d' % (self.stopping_step - self.cur_step))
        return self.best_test_upon_valid
if __name__ == '__main__':
    # Script entry point: parse CLI args, build the training driver, and run it.
    _args = parse_args()
    model = Net(_args)
    try:
        best_score = model.run()
    finally:
        # Close the TensorBoard writer (if one was created) so buffered
        # events are flushed to disk even when training raises.
        if model.writer is not None:
            model.writer.close()
|