REARM / trainer.py
MrShouxingMa's picture
Upload 19 files
f60c555 verified
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn.utils.clip_grad import clip_grad_norm_
# def train(length, epoch, dataloader, model, optimizer, batch_size, writer=None):
def train(self, epoch_idx):
self.model.train()
sum_loss = 0.0
sum_bpr_loss, sum_reg_loss = 0.0, 0.0
sum_diff_loss, sum_mul_vt_cl_loss = 0.0, 0.0
step = 0.0
# bar = tqdm(total=len(self.train_dataset))
# num_bar = 0 self_vt_cl_loss, mul_vt_cl_loss
for batch_idx, interactions in enumerate(self.train_data):
self.optimizer.zero_grad()
loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss = self.model.loss(interactions[0],
interactions[1])
if torch.isnan(loss):
self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
return loss, torch.tensor(0.0)
loss.backward()
self.optimizer.step()
step += 1.0
sum_loss += loss
sum_bpr_loss += bpr_loss
sum_reg_loss += reg_loss
sum_mul_vt_cl_loss += mul_vt_cl_loss
sum_diff_loss += diff_loss
mean_loss = sum_loss / step
mean_bpr_loss = sum_bpr_loss / step
mean_reg_loss = sum_reg_loss / step
mean_mul_vt_cl_loss = sum_mul_vt_cl_loss / step
mean_diff_loss = sum_diff_loss / step
if self.writer is not None:
self.writer.add_scalar('loss/train', mean_loss, epoch_idx)
self.writer.add_scalar('loss/bpr_loss', mean_bpr_loss, epoch_idx)
self.writer.add_scalar('loss/reg_loss', mean_reg_loss, epoch_idx)
self.writer.add_scalar('loss/mul_vt_cl_loss', mean_mul_vt_cl_loss, epoch_idx)
self.writer.add_scalar('loss/diff_loss', mean_diff_loss, epoch_idx)
# bar.close()
return [mean_loss, mean_bpr_loss, mean_reg_loss, mean_mul_vt_cl_loss, mean_diff_loss]