File size: 1,134 Bytes
2df2deb 6f405dd 2df2deb 7c818c6 4349b9b 0049910 4349b9b fa99492 301a9d0 80437f3 8f0b92e 95c34a7 009b27a 8a73975 1582608 8f0b92e 76bd6b1 dcd6392 d0b9a56 8f0b92e db3e9ba 8f0b92e d2803c8 8f0b92e d2803c8 80437f3 d2803c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
from util import Config
class Trainer:
    """Train ``self.model`` with Adam, logging progress after every optimizer step.

    All hyperparameters and collaborators are read from attributes copied off
    the config object passed to ``__init__``. The attributes this class reads
    are ``model``, ``num_epochs``, ``num_batches``, ``learning_rate``,
    ``wandb`` and ``inference`` (with ``frequency``, ``seed_text`` and
    ``n_predict``).
    """

    def __init__(self, config: "Config"):
        """Copy every instance attribute of *config* onto this trainer.

        The copy is shallow: mutable values (e.g. the model) are shared with
        *config*, not duplicated.
        """
        # String annotation: ``Config`` only matters to type checkers, so it is
        # not evaluated when the class is created.
        self.__dict__ = dict(config.__dict__)

    def log(self, loss: float):
        """Report the loss of the step that just finished.

        Args:
            loss: Scalar training loss of the current step.

        Prints a progress line and forwards ``{'epoch', 'batch', 'loss'}`` to
        ``self.wandb``. Every ``self.inference.frequency`` batches (counting
        from batch 0; a frequency of ``0`` disables sampling), it also prints
        text produced by ``self.model.generate_text``.
        """
        # epoch/batch are the 0-based loop counters that ``train`` stores on self.
        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
        args = {'epoch': self.epoch, 'batch': self.batch, 'loss': loss}
        # NOTE(review): presumably a metrics logger such as ``wandb.log`` — confirm.
        self.wandb(args)
        frequency = self.inference.frequency
        # Short-circuit keeps the modulo from running when sampling is disabled (0).
        if frequency != 0 and self.batch % frequency == 0:
            # Sampling is inference only: disable autograd so forward passes
            # inside ``generate_text`` do not record a computation graph in the
            # middle of training.
            with torch.no_grad():
                sample = self.model.generate_text(self.inference.seed_text, self.inference.n_predict)
            print(f'{sample}')

    def train(self, batches):
        """Run ``self.num_epochs`` passes over the first ``self.num_batches`` batches.

        Args:
            batches: Indexable collection. ``batches[i]`` for
                ``0 <= i < self.num_batches`` is passed to
                ``self.model.compute_loss``. It must hold at least
                ``self.num_batches`` items; otherwise indexing fails
                (``IndexError`` for a list).

        Side effects:
            Creates a fresh Adam optimizer (stored as ``self.optimizer``),
            calls ``self.model.unfreeze()``, updates the model's parameters in
            place, and calls ``self.log`` after every optimizer step.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.model.unfreeze()
        # The loop counters live on ``self`` so that ``log`` can read them.
        for self.epoch in range(self.num_epochs):
            for self.batch in range(self.num_batches):
                ids = batches[self.batch]
                loss = self.model.compute_loss(ids)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.log(loss.item())
|