import torch
from util import Config


class Trainer:
    """Minimal training loop whose settings all come from a ``Config``.

    Every attribute of ``config`` is copied onto the trainer, so the methods
    below read their settings straight from ``self``. The code expects the
    config to provide at least:

    * ``model`` -- object exposing ``parameters()``, ``unfreeze()``,
      ``compute_loss(ids)`` and ``generate_text(seed_text, n_predict)``;
    * ``learning_rate`` -- learning rate passed to ``torch.optim.Adam``;
    * ``num_epochs`` / ``num_batches`` -- bounds of the two training loops;
    * ``inference`` -- object with ``frequency``, ``seed_text`` and
      ``n_predict`` controlling periodic text sampling;
    * ``wandb`` -- a callable that receives a metrics dict (presumably a
      Weights & Biases logging hook -- confirm against ``util.Config``).
    """

    def __init__(self, config: Config):
        # Shallow copy: the trainer owns a fresh attribute dict, but the values
        # (model, inference settings, ...) are the same objects as in config.
        self.__dict__ = {**config.__dict__}

    def log(self, loss: float):
        """Report a single optimisation step.

        Prints the current position and loss to stdout, then forwards an
        ``{'epoch', 'batch', 'loss'}`` dict to ``self.wandb``. When
        ``inference.frequency`` is non-zero, it also prints a text sample from
        the model on every batch index divisible by that frequency, including
        batch 0. A frequency of 0 disables sampling.

        Args:
            loss: Scalar loss value of the step that was just taken.
        """
        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
        self.wandb({'epoch': self.epoch, 'batch': self.batch, 'loss': loss})

        every = self.inference.frequency
        # The zero test comes first so the modulo can never divide by zero.
        if every != 0 and self.batch % every == 0:
            sample = self.model.generate_text(self.inference.seed_text, self.inference.n_predict)
            print(f'{sample}')

    def train(self, batches):
        """Optimise ``self.model`` with Adam over the supplied batches.

        Runs ``num_epochs`` passes. Each pass visits ``batches[0]`` through
        ``batches[num_batches - 1]`` in order, so ``batches`` must be indexable
        and hold at least ``num_batches`` items. The zero-based position is
        stored in ``self.epoch`` and ``self.batch`` because ``log`` reads it
        from there.

        Args:
            batches: Indexable collection of inputs accepted by
                ``self.model.compute_loss``.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.model.unfreeze()

        for epoch in range(self.num_epochs):
            self.epoch = epoch
            for batch in range(self.num_batches):
                self.batch = batch
                loss = self.model.compute_loss(batches[batch])
                # Standard update order: clear old grads, backprop, apply step.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.log(loss.item())