# Diffusion-Sprite / train.py
# Commit 8a6ed33 (YashNagraj75): "Add checkpoints; it's still not clear"
import torch
from utils import *
from torch.utils.data import DataLoader
from models import *
from tqdm.auto import tqdm
import os
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Diffusion schedule
timesteps = 500  # number of diffusion steps T
beta1 = 1e-4  # beta at t = 0 (start of the linear schedule)
beta2 = 0.02  # beta at t = T (end of the linear schedule)
# Fall back to CPU so the script still runs on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Network
n_feat = 64  # feature width passed to ContextUnet
n_cfeat = 5  # context feature size passed to ContextUnet (presumably #label classes — confirm)
height = 16  # image side length; the dataset files below are 16x16 sprites

# Training
save_dir = "./checkpoints"
batch_size = 100
n_epoch = 60
lrate = 1e-3

# ---------------------------------------------------------------------------
# Noise schedule
# ---------------------------------------------------------------------------
# Linear beta schedule with T + 1 entries so that a timestep t in [0, T] can
# index these tensors directly.
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
# Cumulative product alpha_bar_t = a_0 * a_1 * ... * a_t.
a_bt = torch.cumprod(a_t, dim=0)
# a_t[0] is 1 - beta1, not 1; force alpha_bar_0 = 1 so t = 0 is the clean image.
a_bt[0] = 1

# ---------------------------------------------------------------------------
# Data and model
# ---------------------------------------------------------------------------
# NOTE(review): `transforms`, `CustomDataset` and `ContextUnet` are not imported
# explicitly in this file; they presumably come from the `utils` / `models`
# star imports — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),  # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)),  # range [-1,1]
])
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
# The first argument (3) is presumably the number of image channels (RGB) — confirm.
nn_model = ContextUnet(3, n_feat, n_cfeat, height).to(device)
optim = torch.optim.Adam(nn_model.parameters(), lrate)
def perturb_input(x, t, noise, ab_t=None):
    """Noise clean images to timestep ``t`` with the DDPM forward process.

    Computes ``sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise``, i.e. a sample
    from q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

    Args:
        x: Batch of clean images. Three singleton axes are appended to the
            per-sample coefficients, so this presumably is
            (batch, channels, height, width) — confirm against the dataloader.
        t: Integer tensor of per-sample timesteps, shape (batch,), used to
            index ``ab_t``.
        noise: Gaussian noise with the same shape as ``x``.
        ab_t: Optional 1-D tensor of cumulative alpha products. Defaults to the
            module-level ``a_bt`` schedule.

    Returns:
        The noised batch, same shape as ``x``.
    """
    if ab_t is None:
        ab_t = a_bt
    # (batch,) -> (batch, 1, 1, 1) so the coefficients broadcast over each image.
    ab = ab_t[t, None, None, None]
    # The noise coefficient is sqrt(1 - ab), not (1 - ab). Only then does the
    # noised sample have the (1 - alpha_bar_t) variance of the DDPM forward
    # process.
    return ab.sqrt() * x + (1 - ab).sqrt() * noise
# ---------------------------------------------------------------------------
# Training loop: learn to predict the noise added by perturb_input.
# ---------------------------------------------------------------------------
nn_model.train()
# Create the checkpoint directory once, up front (no-op if it already exists).
os.makedirs(save_dir, exist_ok=True)
for epoch in range(n_epoch):
    # Linearly decay the learning rate from `lrate` towards 0 over training.
    for group in optim.param_groups:
        group["lr"] = lrate * (1 - epoch / n_epoch)

    # Labels are discarded: the model is called without a context below.
    for x, _ in tqdm(dataloader):
        optim.zero_grad()
        x = x.to(device)
        # One random timestep per image, drawn from [1, timesteps].
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        noise = torch.randn_like(x)
        x_pert = perturb_input(x, t, noise)
        # Timesteps are passed normalised to (0, 1]; the target is the noise.
        pred = nn_model(x_pert, t / timesteps)
        loss = F.mse_loss(pred, noise)
        loss.backward()
        optim.step()

    # Checkpoint every 10 epochs. Also checkpoint after the final epoch:
    # range(n_epoch) stops at n_epoch - 1, which is not a multiple of 10 for
    # the default n_epoch = 60, so the fully trained model would otherwise be
    # lost.
    is_last_epoch = epoch == n_epoch - 1
    if (epoch % 10 == 0 and epoch > 0) or is_last_epoch:
        # `loss` is the loss of the epoch's final batch.
        print(f"Epoch: {epoch} | Loss: {loss.item()}")
        # NOTE(review): this pickles the whole module; saving
        # nn_model.state_dict() is more portable, but any loader of these
        # files would need to change too.
        torch.save(nn_model, os.path.join(save_dir, f"model_Epoch{epoch}.pth"))
        print("Saved model")