Zilong-Zhao's picture
first commit
c4ac745
from copy import deepcopy
import torch
import numpy as np
from scripts.utils_train import update_ema
import pandas as pd
class Trainer:
def __init__(self, diffusion, train_iter, lr, weight_decay, steps, device=torch.device('cuda:1')):
self.diffusion = diffusion
self.ema_model = deepcopy(self.diffusion._denoise_fn)
for param in self.ema_model.parameters():
param.detach_()
self.train_iter = train_iter
self.steps = steps
self.init_lr = lr
self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay)
self.device = device
self.loss_history = pd.DataFrame(columns=['step', 'mloss', 'gloss', 'loss'])
self.log_every = 100
self.print_every = 500
self.ema_every = 1000
def _anneal_lr(self, step):
frac_done = step / self.steps
lr = self.init_lr * (1 - frac_done)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def _run_step(self, x, out_dict):
x = x.to(self.device)
for k in out_dict:
out_dict[k] = out_dict[k].long().to(self.device)
self.optimizer.zero_grad()
loss_multi, loss_gauss = self.diffusion.mixed_loss(x, out_dict)
loss = loss_multi + loss_gauss
loss.backward()
self.optimizer.step()
return loss_multi, loss_gauss
def run_loop(self):
step = 0
curr_loss_multi = 0.0
curr_loss_gauss = 0.0
curr_count = 0
while step < self.steps:
x, out_dict = next(self.train_iter)
out_dict = {'y': out_dict}
batch_loss_multi, batch_loss_gauss = self._run_step(x, out_dict)
self._anneal_lr(step)
curr_count += len(x)
curr_loss_multi += batch_loss_multi.item() * len(x)
curr_loss_gauss += batch_loss_gauss.item() * len(x)
if (step + 1) % self.log_every == 0:
mloss = np.around(curr_loss_multi / curr_count, 4)
gloss = np.around(curr_loss_gauss / curr_count, 4)
if (step + 1) % self.print_every == 0:
print(f'Step {(step + 1)}/{self.steps} MLoss: {mloss} GLoss: {gloss} Sum: {mloss + gloss}')
self.loss_history.loc[len(self.loss_history)] =[step + 1, mloss, gloss, mloss + gloss]
curr_count = 0
curr_loss_gauss = 0.0
curr_loss_multi = 0.0
update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters())
step += 1