| import torch |
| from utils import * |
| from torch.utils.data import DataLoader |
| from models import * |
| from tqdm.auto import tqdm |
| import os |
| import torch.nn.functional as F |
|
|
# --- Diffusion hyperparameters ---
timesteps = 500  # number of noising steps T
beta1 = 1e-4  # first value of the linear beta schedule
beta2 = 0.02  # last value of the linear beta schedule

# --- Network / data hyperparameters ---
# Fall back to the CPU when no CUDA device is available, so the script (and the
# schedule tensors below) can still be built on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_feat = 64  # feature width passed to ContextUnet (presumably hidden channels — confirm)
n_cfeat = 5  # context feature size passed to ContextUnet (presumably number of label classes — confirm)
height = 16  # image side length; the sprite files are named ..._16x16
save_dir = "./checkpoints"  # directory that receives model checkpoints

# --- Optimisation hyperparameters ---
batch_size = 100
n_epoch = 60
lrate = 1e-3  # initial learning rate

# --- Noise schedule ---
# Linear beta schedule with timesteps + 1 entries: b_t[0] == beta1, b_t[-1] == beta2.
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
# Cumulative product of a_t ("alpha-bar"); torch.cumprod is the direct form of
# exp(cumsum(log(a_t))).
a_bt = torch.cumprod(a_t, dim=0)
# Force alpha-bar to exactly 1 at t = 0 so that timestep 0 adds no noise.
a_bt[0] = 1
|
|
|
|
|
|
# Image preprocessing: Normalize((0.5,), (0.5,)) computes (v - 0.5) / 0.5,
# mapping [0, 1] onto [-1, 1]. ToTensor produces [0, 1] for uint8 input —
# assumes the sprite arrays are uint8; confirm.
# NOTE(review): `transforms` is not imported explicitly in this file; it is
# presumably re-exported by one of the `from ... import *` lines — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Sprite images plus their labels. null_context=False is forwarded to
# CustomDataset; its exact meaning is defined there (presumably "keep the real
# context labels") — TODO confirm.
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Context-conditioned U-Net; the leading 3 is presumably the number of image
# channels (RGB) — confirm against ContextUnet's signature.
nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(params=nn_model.parameters(), lr=lrate)
|
|
def perturb_input(x, t, noise, ab_t=None):
    """Noise a batch of clean images forward to the given timesteps.

    Implements the closed-form DDPM forward process
    ``x_t = sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise``. Both coefficients
    are square roots, so the signal and noise variances sum to 1. The previous
    version used ``(1 - ab_t[t])`` without the square root for the noise term.

    Args:
        x: batch of images. ``ab_t[t, None, None, None]`` has shape
            (B, 1, 1, 1), so this assumes a 4-D (B, C, H, W) batch.
        t: integer tensor of timesteps, one per batch element.
        noise: tensor broadcastable to ``x`` (typically ``torch.randn_like(x)``).
        ab_t: optional cumulative-alpha ("alpha-bar") schedule indexed by
            timestep. Defaults to the module-level ``a_bt``.

    Returns:
        The noised batch ``x_t``, with the same shape as ``x``.
    """
    if ab_t is None:
        ab_t = a_bt
    # Per-sample alpha-bar reshaped to (B, 1, 1, 1) for broadcasting over C, H, W.
    ab = ab_t[t, None, None, None]
    return ab.sqrt() * x + (1 - ab).sqrt() * noise
|
|
|
|
nn_model.train()

for epoch in range(n_epoch):
    # Linearly decay the learning rate from lrate towards 0 over training.
    optim.param_groups[0]["lr"] = lrate * (1 - epoch / n_epoch)

    for x, _ in tqdm(dataloader):
        optim.zero_grad()

        x = x.to(device)

        # One random timestep in [1, timesteps] per image, plus the noise to add.
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)

        # The network sees the timestep normalised to (0, 1].
        pred = nn_model(x_pert, t / timesteps)

        # Train the network to predict the noise that was added.
        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    # Checkpoint every 10 epochs and ALWAYS after the last epoch. range(n_epoch)
    # stops at n_epoch - 1, so without the second condition the fully trained
    # model would never be written to disk.
    if (epoch % 10 == 0 and epoch > 0) or epoch == n_epoch - 1:
        # Reports the loss of the epoch's last batch as a plain float.
        print(f"Epoch: {epoch} | Loss: {loss.item()}")
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): this pickles the whole nn.Module, so loading it requires
        # ContextUnet to be importable; saving state_dict() would be more portable.
        torch.save(nn_model, os.path.join(save_dir, f"model_Epoch{epoch}.pth"))
        print("Saved model")
|
|