File size: 1,815 Bytes
35839a1
 
 
 
 
8a6ed33
 
35839a1
 
 
 
 
 
 
 
 
 
 
 
8a6ed33
35839a1
 
 
 
 
 
 
 
 
8a6ed33
 
 
 
 
 
35839a1
 
 
 
8a6ed33
35839a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a6ed33
35839a1
 
 
 
 
 
 
 
 
8a6ed33
 
35839a1
 
8a6ed33
35839a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from utils import *
from torch.utils.data import DataLoader
from models import *
from tqdm.auto import tqdm
import os
import torch.nn.functional as F

# ---- Diffusion hyperparameters ----
timesteps = 500  # number of noising steps; valid timestep indices are 0..timesteps
beta1 = 1e-4     # first value of the linear b_t schedule (index 0)
beta2 = 0.02     # last value of the linear b_t schedule (index `timesteps`)

# ---- Network / runtime settings ----
# Use the GPU when one is present, otherwise fall back to the CPU so the script
# does not crash at the first tensor allocation on CUDA-less machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
n_feat = 64    # passed to ContextUnet; presumably the hidden feature width -- confirm in models
n_cfeat = 5    # passed to ContextUnet; presumably the context vector size -- confirm in models
height = 16    # passed to ContextUnet; presumably the image side length (16x16 sprites) -- confirm
save_dir = "./checkpoints"  # directory the training loop writes checkpoints into

# ---- Training hyperparameters ----
batch_size = 100
n_epoch = 60
lrate = 1e-3  # initial Adam learning rate; decayed per epoch by the training loop


# Noise schedule with timesteps + 1 entries, so it can be indexed directly by t.
# b_t rises linearly from beta1 to beta2.
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
# Cumulative product of a_t, computed as exp(cumsum(log a_t)).
a_bt = torch.cumsum(a_t.log(), 0).exp()
# Index 0 is pinned to 1 so that t = 0 leaves the input un-noised.
a_bt[0] = 1


transform = transforms.Compose([
    transforms.ToTensor(),                # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,))  # range [-1,1]
])
# Image and label arrays are read from the two .npy files by CustomDataset
# (defined outside this file); null_context=False is forwarded unchanged.
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)


# First argument (3) is presumably the number of image channels (RGB) -- confirm in models.
nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(nn_model.parameters(), lrate)

def perturb_input(x, t, noise):
    """Noise a clean batch ``x`` to the timesteps ``t`` (DDPM forward process).

    Computes ``sqrt(a_bt[t]) * x + sqrt(1 - a_bt[t]) * noise``. The square root
    on the noise coefficient keeps the perturbed sample variance-preserving:
    the squared coefficients of signal and noise sum to 1.

    Args:
        x: Batch of clean samples. The schedule values are expanded with three
            trailing singleton axes, so a 4-D tensor with the batch on axis 0
            is expected.
        t: Integer index tensor with one timestep per batch element, used to
            index the module-level ``a_bt`` schedule.
        noise: Noise tensor broadcastable against ``x`` (the training loop
            passes ``torch.randn_like(x)``).

    Returns:
        The noised batch, same shape as ``x``.
    """
    # Index first, then take the sqrt: avoids square-rooting the whole schedule
    # on every call. The result broadcasts over the three trailing axes.
    signal_frac = a_bt[t, None, None, None]
    return signal_frac.sqrt() * x + (1 - signal_frac).sqrt() * noise


# Train the network to predict the noise that was added to each image.
nn_model.train()

for epoch in range(n_epoch):

    # Linear learning-rate decay: lrate at epoch 0, lrate / n_epoch at the last epoch.
    optim.param_groups[0]['lr'] = lrate * (1 - epoch / n_epoch)

    # The second element of each batch is discarded; only the images are used.
    for x, _ in tqdm(dataloader):
        optim.zero_grad()

        x = x.to(device)

        # One random timestep in [1, timesteps] per sample, created directly on
        # the target device instead of on the CPU followed by a copy.
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)

        # The model receives the timestep normalised to (0, 1].
        pred = nn_model(x_pert, t / timesteps)

        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    # Checkpoint every 10th epoch (skipping epoch 0) and always after the final
    # epoch, so the fully trained weights are never lost.
    is_final_epoch = epoch == n_epoch - 1
    if (epoch % 10 == 0 and epoch > 0) or is_final_epoch:
        print(f"Epoch: {epoch} | Loss: {loss}")
        # exist_ok avoids the check-then-create race and also creates parents.
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): this pickles the whole module object, not a state_dict,
        # so loading requires the model class to be importable. Kept as-is so
        # existing loaders of these .pth files keep working.
        torch.save(nn_model, save_dir + f"/model_Epoch{epoch}.pth")
        print("Saved model")