|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from tqdm import tqdm
|
|
|
from torch.nn.utils.clip_grad import clip_grad_norm_
|
|
|
|
|
|
|
|
|
|
|
|
def train(self, epoch_idx):
    """Run one training epoch over ``self.train_data`` and return mean losses.

    For every batch, the gradients in ``self.optimizer`` are cleared, then
    ``self.model.loss`` is evaluated on the first two elements of the batch.
    The total loss is back-propagated and ``self.optimizer`` takes a step.
    The five loss terms returned by the model are accumulated and averaged
    over the number of batches processed.

    Args:
        epoch_idx: Index of the current epoch. It is used in log messages and
            passed as the step argument to ``self.writer.add_scalar``.

    Returns:
        On success, the list ``[mean_loss, mean_bpr_loss, mean_reg_loss,
        mean_mul_vt_cl_loss, mean_diff_loss]``. Tensor entries are detached
        from the autograd graph. Entries stay plain numbers when the model
        returned plain numbers for that term. If the data loader yields no
        batches, every entry is ``0.0``.

        If the total loss is NaN, training stops immediately and the 2-tuple
        ``(loss, torch.tensor(0.0))`` is returned instead, where ``loss`` is
        the offending NaN tensor. NOTE(review): callers presumably rely on
        this different return shape to detect divergence; confirm before
        changing it.
    """
    self.model.train()

    def _detached(value):
        # Loss terms may be tensors or plain Python numbers. Only tensors carry
        # an autograd graph, and it must not be kept alive by the running sums.
        return value.detach() if torch.is_tensor(value) else value

    # Writer tags, in the same order as the model's loss terms, the running
    # sums, and the returned list.
    scalar_tags = (
        'loss/train',
        'loss/bpr_loss',
        'loss/reg_loss',
        'loss/mul_vt_cl_loss',
        'loss/diff_loss',
    )
    sums = [0.0] * len(scalar_tags)
    num_batches = 0

    for batch_idx, interactions in enumerate(self.train_data):
        self.optimizer.zero_grad()
        loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss = self.model.loss(
            interactions[0], interactions[1])

        if torch.isnan(loss):
            self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
            return loss, torch.tensor(0.0)

        loss.backward()
        self.optimizer.step()

        num_batches += 1
        terms = (loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss)
        for i, term in enumerate(terms):
            sums[i] += _detached(term)

    if num_batches == 0:
        # Avoid ZeroDivisionError on an empty loader; report zeros instead.
        self.logger.warning('No training batches at epoch: {}.'.format(epoch_idx))

    means = [total / max(num_batches, 1) for total in sums]

    if self.writer is not None:
        for tag, value in zip(scalar_tags, means):
            self.writer.add_scalar(tag, value, epoch_idx)

    return means
|
|
|
|