import torch
from utils import *
from torch.utils.data import DataLoader
from models import *
from tqdm.auto import tqdm
import os
import torch.nn.functional as F
# --- Diffusion hyperparameters and DDPM noise schedule ---
timesteps = 500  # number of diffusion steps T
beta1 = 1e-4     # beta at t=0: start of the linear variance schedule
beta2 = 0.02     # beta at t=T: end of the linear variance schedule

# NOTE(review): hard-coded CUDA device — this script fails on CPU-only
# machines; confirm that is acceptable for this training setup.
device = "cuda"

# --- Model / training hyperparameters ---
n_feat = 64      # base number of hidden feature channels in the U-Net
n_cfeat = 5      # size of the context (label) vector
height = 16      # sprite images are 16x16
save_dir="./checkpoints"
batch_size = 100
n_epoch = 60
lrate = 1e-3     # initial learning rate (decayed linearly in the loop below)

# b_t: linear beta schedule with timesteps+1 entries (indices 0..T inclusive),
# built directly on the training device.
b_t = (beta2 - beta1) * torch.linspace(0,1,timesteps+1,device=device) + beta1
# a_t: alpha_t = 1 - beta_t
a_t = 1 - b_t
# a_bt: cumulative product of alphas (alpha-bar), computed in log space
# (cumsum of logs, then exp) for numerical stability.
a_bt = torch.cumsum(a_t.log(),0).exp()
# Pin alpha-bar at t=0 to exactly 1 so that step 0 leaves the image unchanged.
a_bt[0] = 1
# Data pipeline: sprites arrive as uint8 images; ToTensor scales them to
# [0,1], then Normalize maps them to [-1,1] to match the model's input range.
transform = transforms.Compose([
    transforms.ToTensor(),                # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,))  # from [0,1] to range [-1,1]
])
# CustomDataset and ContextUnet come from the project-local utils/models
# modules (star-imported above). null_context=False keeps per-sprite labels,
# though this loop discards them (`for x,_ in ...`).
dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
# U-Net noise predictor: 3 input channels (RGB), n_feat hidden channels,
# n_cfeat context features, operating on height x height images.
nn_model = ContextUnet(3,n_feat,n_cfeat,height).to(device)
optim = torch.optim.Adam(nn_model.parameters(),lrate)
def perturb_input(x, t, noise, ab=None):
    """Diffuse a clean batch `x` to timestep `t` via the DDPM forward process.

    Implements q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I), i.e.
        x_t = sqrt(abar_t) * x + sqrt(1 - abar_t) * noise

    BUGFIX: the original multiplied `noise` by (1 - abar_t) without the
    square root, giving the perturbed sample the wrong variance at every t.

    Args:
        x:     (B, C, H, W) clean images.
        t:     (B,) integer timesteps used to index the schedule.
        noise: (B, C, H, W) standard Gaussian noise, same shape as `x`.
        ab:    optional 1-D alpha-bar schedule; defaults to the module-level
               `a_bt` (added for testability; backward-compatible).

    Returns:
        The noised images x_t, same shape as `x`.
    """
    if ab is None:
        ab = a_bt
    # [t, None, None, None] broadcasts the per-sample scalar over (C, H, W).
    return ab.sqrt()[t, None, None, None] * x + (1 - ab[t, None, None, None]).sqrt() * noise
# --- Training loop: standard DDPM objective (MSE between true and predicted noise) ---
nn_model.train()
for epoch in range(n_epoch):
    # Linear learning-rate decay: from `lrate` at epoch 0 down toward 0.
    optim.param_groups[0]['lr'] = lrate * (1-epoch/n_epoch)
    for x,_ in tqdm(dataloader):
        optim.zero_grad()
        x = x.to(device)
        # Sample a uniform random timestep in [1, T] per image in the batch.
        t = torch.randint(1,timesteps+1,(x.shape[0],)).to(device)
        noise = torch.randn_like(x)
        # Diffuse x forward to its noised version at step t.
        x_pert = perturb_input(x,t,noise)
        # The model predicts the injected noise; the timestep is fed in
        # normalized to [0,1] as the model's time conditioning.
        pred = nn_model(x_pert,t / timesteps)
        loss = F.mse_loss(pred,noise)
        loss.backward()
        optim.step()
    # Checkpoint every 10 epochs, skipping epoch 0.
    # NOTE(review): with n_epoch=60 the last save is at epoch 50 — the final
    # model (epoch 59) is never checkpointed; confirm that is intended.
    if epoch % 10 == 0 and epoch > 0:
        # `loss` here is the last mini-batch's loss, not an epoch average.
        print(f"Epoch: {epoch} | Loss: {loss}")
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        # NOTE(review): this pickles the entire module object, which breaks if
        # the class definition moves; a state_dict checkpoint would be more
        # robust — confirm before changing, as loaders may expect this format.
        torch.save(nn_model,save_dir + f"/model_Epoch{epoch}.pth")
        print("Saved model")