import torch

from util import Config


class Trainer:
    def __init__(self, config: Config):
        # Copy every field of the Config object onto the trainer, so values such as
        # model, learning_rate, num_epochs, num_batches, the inference settings, and
        # the wandb logging callable become attributes of self.
        self.__dict__ = dict(config.__dict__)


    def log(self, loss: float):
        # Report training progress to stdout.
        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")

        # Forward the same metrics to the wandb logging callable supplied by the config.
        args = {'epoch': self.epoch, 'batch': self.batch, 'loss': loss}
        self.wandb(args)

        # Periodically sample text from the model so generation quality can be inspected.
        if self.inference.frequency != 0 and self.batch % self.inference.frequency == 0:
            print(self.model.generate_text(self.inference.seed_text, self.inference.n_predict))


    def train(self, batches):
        # Optimize all model parameters with Adam at the configured learning rate.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Unfreeze the model so its parameters are updated during training.
        self.model.unfreeze()

        # epoch and batch are stored on self so that log() can report them.
        for self.epoch in range(self.num_epochs):
            for self.batch in range(self.num_batches):
                ids = batches[self.batch]

                # Standard training step: forward pass, backprop, parameter update.
                loss = self.model.compute_loss(ids)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.log(loss.item())
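

# Example usage (a minimal sketch, not part of the original script): it assumes the
# Config object carries a model exposing compute_loss / generate_text / unfreeze,
# plus learning_rate, num_epochs, num_batches, inference settings, and a wandb
# logging callable, and that `batches` is an indexable collection of token-id
# tensors produced by the project's data pipeline:
#
#     config = Config(...)        # hypothetical construction; see util.Config
#     trainer = Trainer(config)
#     trainer.train(batches)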