File size: 1,986 Bytes
f60c555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn.utils.clip_grad import clip_grad_norm_


# def train(length, epoch, dataloader, model, optimizer, batch_size, writer=None):
def train(self, epoch_idx):
    """Run one training epoch over ``self.train_data`` and return mean losses.

    Written with an explicit ``self`` parameter; presumably attached to a
    trainer object elsewhere (TODO confirm). It reads ``self.model``,
    ``self.optimizer``, ``self.train_data``, ``self.logger`` and
    ``self.writer``.

    For each batch ``interactions``, the first two items are passed to
    ``self.model.loss``. That call must return a 5-tuple
    ``(loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss)``. The routine
    backpropagates ``loss`` and then steps the optimizer.

    Args:
        epoch_idx: Epoch number. Used in the NaN log message and as the
            x-axis step for ``self.writer.add_scalar``.

    Returns:
        On success, the list ``[mean_loss, mean_bpr_loss, mean_reg_loss,
        mean_mul_vt_cl_loss, mean_diff_loss]``, averaged over the batches
        seen. Tensor components are detached from the autograd graph. An
        empty ``train_data`` yields zeros.

        If ``loss`` is NaN for some batch, that batch is logged and the
        2-tuple ``(loss, torch.tensor(0.0))`` is returned immediately,
        without a backward pass. This shape differs from the success
        case, and callers presumably rely on that to detect the failure.
    """
    def _detached(value):
        # Summing a tensor that still requires grad would keep every batch's
        # autograd graph alive until the epoch ends (and make the reported
        # means require grad). Plain Python numbers pass through unchanged.
        return value.detach() if torch.is_tensor(value) else value

    self.model.train()
    sum_loss = 0.0
    sum_bpr_loss, sum_reg_loss = 0.0, 0.0
    sum_diff_loss, sum_mul_vt_cl_loss = 0.0, 0.0

    step = 0
    for batch_idx, interactions in enumerate(self.train_data):
        self.optimizer.zero_grad()
        loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss = self.model.loss(
            interactions[0], interactions[1])
        if torch.isnan(loss):
            self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
            return loss, torch.tensor(0.0)

        loss.backward()
        self.optimizer.step()

        step += 1
        sum_loss += _detached(loss)
        sum_bpr_loss += _detached(bpr_loss)
        sum_reg_loss += _detached(reg_loss)
        sum_mul_vt_cl_loss += _detached(mul_vt_cl_loss)
        sum_diff_loss += _detached(diff_loss)

    # An empty loader would otherwise raise ZeroDivisionError; the sums are
    # all 0.0 in that case, so the reported means are 0.0.
    denom = max(step, 1)
    mean_loss = sum_loss / denom
    mean_bpr_loss = sum_bpr_loss / denom
    mean_reg_loss = sum_reg_loss / denom
    mean_mul_vt_cl_loss = sum_mul_vt_cl_loss / denom
    mean_diff_loss = sum_diff_loss / denom

    if self.writer is not None:
        self.writer.add_scalar('loss/train', mean_loss, epoch_idx)
        self.writer.add_scalar('loss/bpr_loss', mean_bpr_loss, epoch_idx)
        self.writer.add_scalar('loss/reg_loss', mean_reg_loss, epoch_idx)
        self.writer.add_scalar('loss/mul_vt_cl_loss', mean_mul_vt_cl_loss, epoch_idx)
        self.writer.add_scalar('loss/diff_loss', mean_diff_loss, epoch_idx)

    return [mean_loss, mean_bpr_loss, mean_reg_loss, mean_mul_vt_cl_loss, mean_diff_loss]