REARM / main.py
MrShouxingMa's picture
Upload 19 files
f60c555 verified
import os
import torch
import platform
from time import time
from tqdm import tqdm
from trainer import train
import torch.optim as optim
from utils.parser import parse_args
from utils.logger import init_logger
from utils.configurator import Config
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from utils.evaluator import evaluate_model
from utils.data_loader import Load_dataset, Load_eval_dataset
from utils.helper import early_stopping, plot_curve, res_output, stop_log, update_result, sele_para
from model import REARM
class Net:
    """Experiment driver: wires configuration, data, model and optimizer together.

    Construction builds a ``Config`` from the parsed arguments, sets up the
    logger, creates the training ``DataLoader`` plus the validation/test
    evaluation datasets, instantiates ``REARM`` on the configured device and
    attaches an ``AdamW`` optimizer with a ``LambdaLR`` schedule.

    ``run`` then executes the train / validate / test loop with early stopping.
    """

    def __init__(self, args):
        """Build every object needed for a run from the parsed arguments.

        Args:
            args: Parsed command-line arguments, forwarded unchanged to ``Config``.
        """
        # Config is expected to complete parameter initialization
        # (including random seeds, per the original author's note).
        self.config = Config(args)

        # Logger: record the full configuration and where the run executes.
        self.logger = init_logger(self.config)
        self.logger.info(self.config)
        self.logger.info('██Server: \t' + platform.node())
        self.logger.info('██Dir: \t' + os.getcwd() + '\n')

        # Cache frequently used configuration values on the instance.
        self.device = self.config.device
        self.model_name = self.config.model_name
        self.dataset_name = self.config.dataset
        self.batch_size = self.config.batch_size
        self.num_workers = self.config.num_workers
        self.learning_rate = self.config.learning_rate
        self.num_epoch = self.config.num_epoch
        self.topk = self.config.topk
        self.metrics = self.config.metrics
        self.valid_metric = self.config.valid_metric
        self.stopping_step = self.config.stopping_step
        self.reg_weight = self.config.reg_weight

        # Early-stopping / best-result bookkeeping, updated by run().
        self.cur_step = 0
        self.best_valid_score = -1
        self.best_valid_result = {}
        self.best_test_upon_valid = {}

        # Writer will output to ./runs/ directory by default.
        self.writer = SummaryWriter() if self.config.writer else None

        # Data: one shuffled training loader plus validation/test eval datasets.
        Dataset = Load_dataset(self.config)
        valid_dataset, test_dataset = Dataset.load_eval_data()
        self.train_data = DataLoader(
            Dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        self.valid_data = Load_eval_dataset("Validation", self.config, valid_dataset)
        self.test_data = Load_eval_dataset("Testing", self.config, test_dataset)

        # Model and optimizer; reg_weight is applied as AdamW weight decay.
        self.model = REARM(self.config, Dataset).to(self.device)
        self.optimizer = optim.AdamW(
            self.model.parameters(), self.learning_rate, weight_decay=self.reg_weight
        )

        # learning_rate_scheduler is indexed as [decay_rate, decay_step].
        lr_scheduler = self.config.learning_rate_scheduler
        fac = self._exponential_decay(lr_scheduler[0], lr_scheduler[1])
        self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)

        self.logger.info(self.model)

    @staticmethod
    def _exponential_decay(decay_rate, decay_step):
        """Return ``f(epoch) = decay_rate ** (epoch / decay_step)`` for ``LambdaLR``."""

        def factor(epoch):
            return decay_rate ** (epoch / decay_step)

        return factor

    def _flush_writer(self):
        """Flush pending TensorBoard events to disk; no-op when no writer is configured."""
        if self.writer is not None:
            self.writer.flush()

    def plot_train_loss(self):
        """Delegate plotting to ``plot_curve``, which receives this instance."""
        plot_curve(self)

    def run(self):
        """Train for up to ``num_epoch`` epochs with validation-based early stopping.

        Each epoch trains, steps the LR schedule, evaluates on the validation
        data and, when ``early_stopping`` sets its update flag, evaluates on the
        test data and records the result through ``update_result``.

        Returns:
            ``self.best_test_upon_valid``. If the training loss becomes NaN the
            run aborts early and returns that same attribute, or
            ``{"Recall@20": -1}`` when it is still empty.
        """
        run_start_time = time()
        try:
            for epoch_idx in tqdm(range(self.num_epoch)):
                train_start_time = time()
                train_loss = train(self, epoch_idx)

                # Abort on a NaN loss, reporting the best result seen so far.
                if torch.isnan(train_loss[0]):
                    ret_value = {"Recall@20": -1} if self.best_test_upon_valid == {} else self.best_test_upon_valid
                    stop_output = '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
                    self.logger.info(stop_output)
                    self.logger.info(
                        'Loss is nan at epoch: {}; last value is {}. Exiting.'.format(epoch_idx, ret_value))
                    return ret_value

                self.lr_scheduler.step()
                train_output = res_output(epoch_idx, train_start_time, time(), train_loss, "train")
                self.logger.info(train_output)

                # Validation pass; early_stopping returns the updated best score,
                # the step counter, and the stop/update flags.
                valid_start_time = time()
                valid_score, valid_result = evaluate_model(self, epoch_idx, self.valid_data, t_or_v="valid")
                self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                    valid_score, self.best_valid_score, self.cur_step, self.stopping_step)
                # Best validation score known as of this epoch.
                self.best_valid_result[epoch_idx] = self.best_valid_score
                valid_output = res_output(epoch_idx, valid_start_time, time(), valid_result, t_or_v="valid")
                self.logger.info(valid_output)

                if update_flag:
                    # Presumably set when the validation score improved (verify in
                    # utils.helper.early_stopping): evaluate on test and record it.
                    test_start_time = time()
                    _, test_result = evaluate_model(self, epoch_idx, self.test_data, t_or_v="test")
                    test_score_output = res_output(epoch_idx, test_start_time, time(), test_result, t_or_v="test")
                    self.logger.info(test_score_output)
                    update_result(self, test_result)

                if stop_flag:
                    stop_log(self, epoch_idx, run_start_time)
                    break
                print('patience ==> %d' % (self.stopping_step - self.cur_step))
            return self.best_test_upon_valid
        finally:
            # Make sure buffered TensorBoard events reach disk on every exit path
            # (normal finish, early stop, NaN abort, or an exception).
            self._flush_writer()
if __name__ == "__main__":
    # Script entry point: parse the command line, build the experiment,
    # then train and keep whatever run() returns as the best score.
    _args = parse_args()
    model = Net(args=_args)
    best_score = model.run()