"""Train a context-conditioned UNet noise predictor (DDPM) on 16x16 sprites.

Samples a random timestep per image, perturbs the image with the DDPM
forward process, and trains the model to predict the added noise (MSE).
Checkpoints the full model every 10 epochs.
"""
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from models import *
from utils import *

# Diffusion / model hyperparameters.
timesteps = 500
beta1 = 1e-4            # noise-schedule start value
beta2 = 0.02            # noise-schedule end value
device = "cuda"
n_feat = 64             # base feature-map width of the UNet
n_cfeat = 5             # context (label) vector length
height = 16             # sprite height/width in pixels
save_dir = "./checkpoints"

# Training hyperparameters.
batch_size = 100
n_epoch = 60
lrate = 1e-3

# Linear beta schedule over timesteps+1 entries; index 0 is a dummy so that
# valid timesteps are 1..timesteps and a_bt[0] corresponds to the clean image.
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
a_bt = torch.cumsum(a_t.log(), 0).exp()  # cumulative product of alphas (log-space for stability)
a_bt[0] = 1

transform = transforms.Compose([
    transforms.ToTensor(),                 # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)),  # range [-1,1]
])

dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(nn_model.parameters(), lrate)


def perturb_input(x, t, noise):
    """Apply the DDPM forward process q(x_t | x_0) to a batch.

    x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise

    Args:
        x: clean image batch, shape (B, C, H, W).
        t: integer timesteps, shape (B,), values in 1..timesteps.
        noise: standard-normal noise, same shape as ``x``.

    Returns:
        The noised batch x_t.

    Note: fixed a bug where the noise was scaled by (1 - a_bar_t) instead of
    sqrt(1 - a_bar_t); the DDPM forward process requires the square root so
    that x_t has the correct variance (Ho et al., 2020, Eq. 4).
    """
    return (
        a_bt.sqrt()[t, None, None, None] * x
        + (1 - a_bt[t, None, None, None]).sqrt() * noise
    )


nn_model.train()
for epoch in range(n_epoch):
    # Linearly decay the learning rate to 0 over the course of training.
    optim.param_groups[0]['lr'] = lrate * (1 - epoch / n_epoch)

    for x, _ in tqdm(dataloader):
        optim.zero_grad()
        x = x.to(device)
        # Per-image random timestep in 1..timesteps (0 is the clean image).
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)
        # The model predicts the injected noise given the normalized timestep.
        pred = nn_model(x_pert, t / timesteps)
        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    if epoch % 10 == 0 and epoch > 0:
        print(f"Epoch: {epoch} | Loss: {loss}")
        # makedirs(exist_ok=True) is race-safe, unlike exists()+mkdir().
        os.makedirs(save_dir, exist_ok=True)
        torch.save(nn_model, save_dir + f"/model_Epoch{epoch}.pth")
        print("Saved model")