| import os |
| from typing import Dict |
| import torch |
| import torch.optim as optim |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from torchvision.datasets import CIFAR10 |
| from torchvision.utils import save_image |
| from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer |
| from model import UNet |
| from scheduler import GradualWarmupScheduler |
|
|
def train(modelConfig: Dict):
    """Train a DDPM UNet on CIFAR-10 and checkpoint the full model every epoch.

    modelConfig keys read: device, batch_size, T, channel, ch_mult, attn,
    num_res_blocks, dropout, lr, epochs, multiplier, beta_1, beta_T,
    grad_clip, checkpoint_dir.
    """
    device = torch.device(modelConfig['device'])

    # CIFAR-10 scaled to [-1, 1]; eval() undoes this with x * 0.5 + 0.5.
    dataset = CIFAR10(
        "./", train=True, download=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    )

    dataloader = DataLoader(
        dataset, batch_size=modelConfig['batch_size'], shuffle=True,
        num_workers=4, drop_last=True, pin_memory=True
    )

    net_model = UNet(modelConfig['T'], modelConfig['channel'], modelConfig['ch_mult'],
                     modelConfig['attn'], modelConfig['num_res_blocks'],
                     modelConfig['dropout'])
    optimizer = optim.AdamW(net_model.parameters(), modelConfig['lr'], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, modelConfig['epochs'], eta_min=0, last_epoch=-1
    )

    # Warm up over the first 10% of epochs, then hand off to cosine annealing.
    warmupScheduler = GradualWarmupScheduler(
        optimizer, modelConfig['multiplier'], modelConfig['epochs'] // 10,
        cosineScheduler
    )

    # .to(device) moves net_model's parameters in-place (net_model is a
    # submodule of the trainer), so the optimizer built above stays valid.
    trainer = GaussianDiffusionTrainer(
        modelConfig['beta_1'],
        modelConfig['beta_T'],
        modelConfig['T'],
        net_model).to(device)

    # Ensure the checkpoint directory exists before the first epoch's save.
    os.makedirs(modelConfig['checkpoint_dir'], exist_ok=True)

    for epoch in range(modelConfig['epochs']):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, _ in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # trainer(x_0) yields per-element losses; the /1000 scaling
                # matches the original loss scale of this codebase.
                loss = trainer(x_0).sum() / 1000
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig['grad_clip']
                )
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": epoch,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
        warmupScheduler.step()
        # BUG FIX: the original passed ONE pre-concatenated string to
        # os.path.join (checkpoint_dir + filename), so the filename was glued
        # onto the directory without a separator. Join the two parts properly.
        torch.save(net_model,
                   os.path.join(modelConfig['checkpoint_dir'], f"ckpt_{epoch}.pth"))
|
|
|
|
def eval(modelConfig: Dict):
    """Load a trained checkpoint, sample one batch of images, save a grid.

    modelConfig keys read: device, checkpoint_dir, test_load_weight,
    beta_1, beta_T, T, batch_size, sample_dir, sampledImgName, nrow.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        # NOTE(review): torch.load of a full pickled model executes arbitrary
        # code from the file — only load checkpoints produced by train().
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            device
        )
        print("Model loaded")
        model.eval()
        # Consistency fix: move the sampler to `device` like the trainer in
        # train(); noisyImage below lives on `device`, so any sampler-internal
        # tensors must too. No-op when device is CPU.
        # NOTE(review): argument order here (beta_1, beta_T, model, T) differs
        # from the trainer's (beta_1, beta_T, T, model) — confirm against
        # the signatures in diffusion.py.
        sampler = GaussianDiffusionSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T']
        ).to(device)
        # Start the reverse process from pure Gaussian noise, CIFAR-10 shaped.
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device
        )
        sampledImgs = sampler(noisyImage)
        # Undo the train-time Normalize: map [-1, 1] back to [0, 1].
        sampledImgs = sampledImgs * 0.5 + 0.5
        # Ensure the output directory exists before writing the image grid.
        os.makedirs(modelConfig['sample_dir'], exist_ok=True)
        save_image(
            sampledImgs,
            os.path.join(modelConfig['sample_dir'], modelConfig['sampledImgName']),
            nrow=modelConfig['nrow']
        )
|
|
|
|