# Diffusion-CIFAR10 / train.py
# Author: Yash Nagraj
# Commit b5ea21b ("Dumb fix")
import os
from typing import Dict
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from model import UNet
from scheduler import GradualWarmupScheduler
def train(modelConfig: Dict):
    """Train the diffusion UNet on CIFAR-10, checkpointing after every epoch.

    Args:
        modelConfig: Hyper-parameter dictionary. Keys read here: ``device``,
            ``batch_size``, ``T``, ``channel``, ``ch_mult``, ``attn``,
            ``num_res_blocks``, ``dropout``, ``lr``, ``epochs``,
            ``multiplier``, ``beta_1``, ``beta_T``, ``grad_clip`` and
            ``checkpoint_dir``.

    Side effects:
        Downloads CIFAR-10 into ``./`` if needed, creates ``checkpoint_dir``
        if it does not exist, and writes ``ckpt_<epoch>.pth`` there after
        each epoch. The whole pickled module is saved, not just a state dict.
    """
    device = torch.device(modelConfig['device'])

    # Random flips for augmentation. Normalize(0.5, 0.5) maps pixel values
    # from [0, 1] to [-1, 1] per channel.
    dataset = CIFAR10(
        "./", train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )
    dataloader = DataLoader(
        dataset, batch_size=modelConfig['batch_size'], shuffle=True,
        num_workers=4, drop_last=True, pin_memory=True,
    )

    net_model = UNet(
        modelConfig['T'], modelConfig['channel'], modelConfig['ch_mult'],
        modelConfig['attn'], modelConfig['num_res_blocks'], modelConfig['dropout'],
    ).to(device)
    optimizer = optim.AdamW(net_model.parameters(), modelConfig['lr'], weight_decay=1e-4)
    # Cosine decay over the full run. GradualWarmupScheduler is project-local;
    # its third argument (epochs // 10) is presumably the warm-up length --
    # confirm in scheduler.py. The argument order is unchanged.
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, modelConfig['epochs'], eta_min=0, last_epoch=-1
    )
    warmupScheduler = GradualWarmupScheduler(
        optimizer, modelConfig['multiplier'], modelConfig['epochs'] // 10,
        cosineScheduler
    )
    trainer = GaussianDiffusionTrainer(
        modelConfig['beta_1'],
        modelConfig['beta_T'],
        modelConfig['T'],
        net_model,
    ).to(device)

    checkpoint_dir = modelConfig['checkpoint_dir']
    # Create the output directory up front so the first torch.save cannot fail.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(modelConfig['epochs']):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, _ in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # Sum of the trainer's output (presumably per-element losses),
                # scaled by a constant 1/1000.
                loss = trainer(x_0).sum() / 1000
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig['grad_clip']
                )
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": epoch,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    # Read the live param group rather than building a full
                    # optimizer.state_dict() copy on every batch.
                    "LR": optimizer.param_groups[0]["lr"],
                })
        # The LR schedule is stepped once per epoch.
        warmupScheduler.step()
        # Pass the directory and file name as separate components. Previously
        # they were string-concatenated, so a directory without a trailing
        # separator produced a wrong path (e.g. "./ckptsckpt_0.pth").
        torch.save(net_model, os.path.join(checkpoint_dir, f"ckpt_{epoch}.pth"))
def eval(modelConfig: Dict):
    """Sample one batch of images from a trained checkpoint and save it as a grid.

    The name shadows the builtin ``eval`` within this module; it is kept
    unchanged so existing callers keep working.

    Args:
        modelConfig: Hyper-parameter dictionary. Keys read here: ``device``,
            ``checkpoint_dir``, ``test_load_weight``, ``beta_1``, ``beta_T``,
            ``T``, ``batch_size``, ``sample_dir``, ``sampledImgName`` and
            ``nrow``.

    Side effects:
        Prints "Model loaded", creates ``sample_dir`` if missing, and writes
        the image grid to ``sample_dir/sampledImgName``.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        # train() saves the whole pickled nn.Module rather than a state_dict,
        # so the file must be fully unpickled. Since PyTorch 2.6, torch.load
        # defaults to weights_only=True, which rejects such files, so we opt
        # out explicitly. The weights_only keyword requires torch >= 1.13.
        # Unpickling can execute arbitrary code: only load checkpoints you
        # produced yourself.
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            map_location=device,
            weights_only=False,
        )
        print("Model loaded")
        model.eval()
        # NOTE(review): the sampler is called as (beta_1, beta_T, model, T),
        # whereas GaussianDiffusionTrainer in train() takes
        # (beta_1, beta_T, T, model). Confirm the order against diffusion.py.
        sampler = GaussianDiffusionSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T']
        )
        # Start from pure Gaussian noise at CIFAR-10 resolution (3 x 32 x 32).
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device,
        )
        sampledImgs = sampler(noisyImage)
        # Undo the training-time Normalize(0.5, 0.5). This assumes the sampler
        # returns images in [-1, 1]; the result is then in [0, 1].
        sampledImgs = sampledImgs * 0.5 + 0.5
        os.makedirs(modelConfig['sample_dir'], exist_ok=True)
        save_image(
            sampledImgs,
            os.path.join(modelConfig['sample_dir'], modelConfig['sampledImgName']),
            nrow=modelConfig['nrow'],
        )