import os
from typing import Dict

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from model import UNet
from scheduler import GradualWarmupScheduler


def train(modelConfig: Dict) -> None:
    """Train a UNet with a Gaussian diffusion objective on CIFAR-10.

    The whole model object is saved to ``checkpoint_dir`` as
    ``ckpt_{epoch}.pth`` after every epoch.

    Args:
        modelConfig: Configuration mapping. Keys read here: ``device``,
            ``batch_size``, ``T``, ``channel``, ``ch_mult``, ``attn``,
            ``num_res_blocks``, ``dropout``, ``lr``, ``epochs``,
            ``multiplier``, ``beta_1``, ``beta_T``, ``grad_clip`` and
            ``checkpoint_dir``.
    """
    device = torch.device(modelConfig['device'])

    # ToTensor followed by Normalize(mean=0.5, std=0.5) computes
    # (x - 0.5) / 0.5 per channel; eval() undoes this with `* 0.5 + 0.5`.
    dataset = CIFAR10(
        "./",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=modelConfig['batch_size'],
        shuffle=True,
        num_workers=4,
        drop_last=True,
        pin_memory=True,
    )

    net_model = UNet(
        modelConfig['T'],
        modelConfig['channel'],
        modelConfig['ch_mult'],
        modelConfig['attn'],
        modelConfig['num_res_blocks'],
        modelConfig['dropout'],
    )
    optimizer = optim.AdamW(
        net_model.parameters(), modelConfig['lr'], weight_decay=1e-4
    )
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, modelConfig['epochs'], eta_min=0, last_epoch=-1
    )
    # Warm-up length is a tenth of the total epochs; afterwards the wrapped
    # cosine scheduler is handed to GradualWarmupScheduler as its successor.
    warmupScheduler = GradualWarmupScheduler(
        optimizer,
        modelConfig['multiplier'],
        modelConfig['epochs'] // 10,
        cosineScheduler,
    )
    trainer = GaussianDiffusionTrainer(
        modelConfig['beta_1'], modelConfig['beta_T'], modelConfig['T'], net_model
    ).to(device)

    # Create the checkpoint directory up front so the first torch.save does
    # not fail after a full epoch of training. An empty string means "current
    # directory" for os.path.join, and os.makedirs("") would raise.
    checkpoint_dir = modelConfig['checkpoint_dir']
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(modelConfig['epochs']):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, _ in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # NOTE(review): the divisor 1000 is a fixed loss scale taken
                # over the summed (not averaged) loss; confirm it is intended.
                loss = trainer(x_0).sum() / 1000
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig['grad_clip']
                )
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": epoch,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    # Read the LR directly instead of rebuilding the whole
                    # optimizer state dict on every batch.
                    "LR": optimizer.param_groups[0]["lr"],
                })
        # The schedulers are stepped once per epoch, matching T_max=epochs.
        warmupScheduler.step()
        # Join directory and file name as separate components so the path is
        # correct whether or not checkpoint_dir ends with a separator; this
        # matches how eval() builds the path it loads from.
        torch.save(
            net_model, os.path.join(checkpoint_dir, f"ckpt_{epoch}.pth")
        )


def eval(modelConfig: Dict) -> None:
    """Load a saved model, sample a batch of images and write them to disk.

    The function name shadows the ``eval`` builtin inside this module; it is
    kept unchanged because callers import it under this name.

    Args:
        modelConfig: Configuration mapping. Keys read here: ``device``,
            ``checkpoint_dir``, ``test_load_weight``, ``beta_1``, ``beta_T``,
            ``T``, ``batch_size``, ``sample_dir``, ``sampledImgName`` and
            ``nrow``.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        checkpoint_path = os.path.join(
            modelConfig['checkpoint_dir'], modelConfig['test_load_weight']
        )
        # NOTE(review): train() pickles the whole module, so this unpickles
        # arbitrary objects; only load checkpoints you trust. Recent PyTorch
        # releases default torch.load to weights_only=True, which may reject
        # such files -- confirm against the installed version.
        model = torch.load(checkpoint_path, map_location=device)
        print("Model loaded")
        model.eval()

        sampler = GaussianDiffusionSampler(
            modelConfig['beta_1'], modelConfig['beta_T'], model, modelConfig['T']
        )
        # Start from standard-normal noise shaped as a batch of 3x32x32 images.
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32], device=device
        )
        sampledImgs = sampler(noisyImage)
        # Invert the (x - 0.5) / 0.5 normalisation applied to training images.
        sampledImgs = sampledImgs * 0.5 + 0.5

        # Make sure the output directory exists before writing the image grid.
        sample_dir = modelConfig['sample_dir']
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)
        save_image(
            sampledImgs,
            os.path.join(sample_dir, modelConfig['sampledImgName']),
            nrow=modelConfig['nrow'],
        )